diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index 9f42130..b539a67 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -9,21 +9,23 @@ include_directories(${RNLLAMA_LIB_DIR}) set( SOURCE_FILES + ${RNLLAMA_LIB_DIR}/llama-grammar.cpp + ${RNLLAMA_LIB_DIR}/llama-sampling.cpp + ${RNLLAMA_LIB_DIR}/llama-vocab.cpp + ${RNLLAMA_LIB_DIR}/log.cpp + + ${RNLLAMA_LIB_DIR}/ggml-aarch64.c ${RNLLAMA_LIB_DIR}/ggml-alloc.c - ${RNLLAMA_LIB_DIR}/ggml-backend.c + ${RNLLAMA_LIB_DIR}/ggml-backend.cpp ${RNLLAMA_LIB_DIR}/ggml.c ${RNLLAMA_LIB_DIR}/ggml-quants.c ${RNLLAMA_LIB_DIR}/common.cpp - ${RNLLAMA_LIB_DIR}/grammar-parser.cpp ${RNLLAMA_LIB_DIR}/json.hpp ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp ${RNLLAMA_LIB_DIR}/sampling.cpp ${RNLLAMA_LIB_DIR}/unicode-data.cpp ${RNLLAMA_LIB_DIR}/unicode.cpp ${RNLLAMA_LIB_DIR}/llama.cpp - ${RNLLAMA_LIB_DIR}/llama-vocab.cpp - ${RNLLAMA_LIB_DIR}/llama-sampling.cpp - ${RNLLAMA_LIB_DIR}/llama-grammar.cpp ${RNLLAMA_LIB_DIR}/sgemm.cpp ${RNLLAMA_LIB_DIR}/ggml-aarch64.c ${RNLLAMA_LIB_DIR}/rn-llama.hpp @@ -65,10 +67,20 @@ build_library("rnllama" "") if (${ANDROID_ABI} STREQUAL "arm64-v8a") # ARM64 targets + build_library("rnllama_v8_4_fp16_dotprod_sve" "-march=armv8.4-a+fp16+dotprod+sve") + build_library("rnllama_v8_4_fp16_dotprod_i8mm_sve" "-march=armv8.4-a+fp16+dotprod+i8mm+sve") + build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm") build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod") build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod") build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16") build_library("rnllama_v8" "-march=armv8-a") + + # https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk + # llama.cpp will deal with the cpu features + # build_library("rnllama_v8_7" "-march=armv8.7-a") + # TODO: Add support runtime check for cpu features + # At the moment runtime check is failing. + elseif (${ANDROID_ABI} STREQUAL "x86_64") # x86_64 target build_library("rnllama_x86_64" "-march=x86-64" "-mtune=intel" "-msse4.2" "-mpopcnt") diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index 55d90f3..cb1cb19 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -35,6 +35,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa if (!params.hasKey("model")) { throw new IllegalArgumentException("Missing required parameter: model"); } + Log.d(NAME, "Setting log callback"); + logToAndroid(); this.id = id; this.context = initContext( // String model, @@ -53,6 +55,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true, // boolean use_mmap, params.hasKey("use_mmap") ? params.getBoolean("use_mmap") : true, + //boolean vocab_only, + params.hasKey("vocab_only") ? params.getBoolean("vocab_only") : false, // String lora, params.hasKey("lora") ? params.getString("lora") : "", // float lora_scaled, @@ -181,6 +185,10 @@ public WritableMap completion(ReadableMap params) { params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f, // float min_p, params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f, + // float xtc_threshold, + params.hasKey("xtc_threshold") ? (float) params.getDouble("xtc_threshold") : 0.00f, + // float xtc_probability, + params.hasKey("xtc_probability") ? (float) params.getDouble("xtc_probability") : 0.00f, // float tfs_z, params.hasKey("tfs_z") ? (float) params.getDouble("tfs_z") : 1.00f, // float typical_p, @@ -248,16 +256,34 @@ public void release() { static { Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]); - if (LlamaContext.isArm64V8a()) { - String cpuFeatures = LlamaContext.getCpuFeatures(); - Log.d(NAME, "CPU features: " + cpuFeatures); - - boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp"); - boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp"); - boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes"); - boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat"); - if (isAtLeastArmV84 && hasFp16 && hasDotProd) { + String cpuFeatures = LlamaContext.getCpuFeatures(); + Log.d(NAME, "CPU features: " + cpuFeatures); + boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp"); + boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp"); + boolean hasSve = cpuFeatures.contains("sve"); + boolean hasI8mm = cpuFeatures.contains("i8mm"); + boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes"); + boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat"); + Log.d(NAME, "- hasFp16: " + hasFp16); + Log.d(NAME, "- hasDotProd: " + hasDotProd); + Log.d(NAME, "- hasSve: " + hasSve); + Log.d(NAME, "- hasI8mm: " + hasI8mm); + Log.d(NAME, "- isAtLeastArmV82: " + isAtLeastArmV82); + Log.d(NAME, "- isAtLeastArmV84: " + isAtLeastArmV84); + + // TODO: Add runtime check for cpu features + if (LlamaContext.isArm64V8a()) { + if (isAtLeastArmV84 && hasSve && hasI8mm && hasFp16 && hasDotProd) { + Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm_sve.so"); + System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm_sve"); + } else if (isAtLeastArmV84 && hasSve && hasFp16 && hasDotProd) { + Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_sve.so"); + System.loadLibrary("rnllama_v8_4_fp16_dotprod_sve"); + } else if (isAtLeastArmV84 && hasI8mm && hasFp16 && hasDotProd) { + Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm.so"); + System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm"); + } else if (isAtLeastArmV84 && hasFp16 && hasDotProd) { Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod.so"); System.loadLibrary("rnllama_v8_4_fp16_dotprod"); } else if (isAtLeastArmV82 && hasFp16 && hasDotProd) { @@ -270,14 +296,16 @@ public void release() { Log.d(NAME, "Loading librnllama_v8.so"); System.loadLibrary("rnllama_v8"); } + // Log.d(NAME, "Loading librnllama_v8_7.so with runtime feature detection"); + // System.loadLibrary("rnllama_v8_7"); } else if (LlamaContext.isX86_64()) { - Log.d(NAME, "Loading librnllama_x86_64.so"); - System.loadLibrary("rnllama_x86_64"); + Log.d(NAME, "Loading librnllama_x86_64.so"); + System.loadLibrary("rnllama_x86_64"); } else { - Log.d(NAME, "Loading default librnllama.so"); - System.loadLibrary("rnllama"); + Log.d(NAME, "Loading default librnllama.so"); + System.loadLibrary("rnllama"); } - } +} private static boolean isArm64V8a() { return Build.SUPPORTED_ABIS[0].equals("arm64-v8a"); @@ -316,6 +344,7 @@ protected static native long initContext( int n_gpu_layers, // TODO: Support this boolean use_mlock, boolean use_mmap, + boolean vocab_only, String lora, float lora_scaled, float rope_freq_base, @@ -357,6 +386,8 @@ protected static native WritableMap doCompletion( int top_k, float top_p, float min_p, + float xtc_threshold, + float xtc_probability, float tfs_z, float typical_p, int seed, @@ -373,4 +404,5 @@ protected static native WritableMap doCompletion( protected static native WritableMap embedding(long contextPtr, String text); protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr); protected static native void freeContext(long contextPtr); + protected static native void logToAndroid(); } diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 5377aea..4ad058b 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -3,6 +3,7 @@ // #include #include #include +#include #include #include #include @@ -21,6 +22,13 @@ static inline int min(int a, int b) { return (a < b) ? a : b; } +static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) { + if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data); + else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data); + else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data); + else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data); +} + extern "C" { // Method to create WritableMap @@ -139,6 +147,7 @@ Java_com_rnllama_LlamaContext_initContext( jint n_gpu_layers, // TODO: Support this jboolean use_mlock, jboolean use_mmap, + jboolean vocab_only, jstring lora_str, jfloat lora_scaled, jfloat rope_freq_base, @@ -146,7 +155,12 @@ Java_com_rnllama_LlamaContext_initContext( ) { UNUSED(thiz); - gpt_params defaultParams; + common_params defaultParams; + + defaultParams.vocab_only = vocab_only; + if(vocab_only) { + defaultParams.warmup = false; + } const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr); defaultParams.model = model_path_chars; @@ -159,7 +173,7 @@ Java_com_rnllama_LlamaContext_initContext( int max_threads = std::thread::hardware_concurrency(); // Use 2 threads by default on 4-core devices, 4 threads on more cores int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads); - defaultParams.n_threads = n_threads > 0 ? n_threads : default_n_threads; + defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads; defaultParams.n_gpu_layers = n_gpu_layers; @@ -235,7 +249,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat( UNUSED(thiz); auto llama = context_map[(long) context_ptr]; - std::vector chat; + std::vector chat; int messages_len = env->GetArrayLength(messages); for (int i = 0; i < messages_len; i++) { @@ -259,7 +273,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat( } const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr); - std::string formatted_chat = llama_chat_apply_template(llama->model, tmpl_chars, chat, true); + std::string formatted_chat = common_chat_apply_template(llama->model, tmpl_chars, chat, true); return env->NewStringUTF(formatted_chat.c_str()); } @@ -364,7 +378,8 @@ Java_com_rnllama_LlamaContext_doCompletion( jint top_k, jfloat top_p, jfloat min_p, - jfloat tfs_z, + jfloat xtc_threshold, + jfloat xtc_probability, jfloat typical_p, jint seed, jobjectArray stop, @@ -377,18 +392,18 @@ Java_com_rnllama_LlamaContext_doCompletion( llama->rewind(); - llama_reset_timings(llama->ctx); + //llama_reset_timings(llama->ctx); llama->params.prompt = env->GetStringUTFChars(prompt, nullptr); - llama->params.seed = seed; + llama->params.sparams.seed = (seed == -1) ? time(NULL) : seed; int max_threads = std::thread::hardware_concurrency(); // Use 2 threads by default on 4-core devices, 4 threads on more cores int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads); - llama->params.n_threads = n_threads > 0 ? n_threads : default_n_threads; + llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads; llama->params.n_predict = n_predict; - llama->params.ignore_eos = ignore_eos; + llama->params.sparams.ignore_eos = ignore_eos; auto & sparams = llama->params.sparams; sparams.temp = temperature; @@ -403,14 +418,15 @@ Java_com_rnllama_LlamaContext_doCompletion( sparams.top_k = top_k; sparams.top_p = top_p; sparams.min_p = min_p; - sparams.tfs_z = tfs_z; - sparams.typical_p = typical_p; + sparams.typ_p = typical_p; sparams.n_probs = n_probs; sparams.grammar = env->GetStringUTFChars(grammar, nullptr); + sparams.xtc_threshold = xtc_threshold; + sparams.xtc_probability = xtc_probability; sparams.logit_bias.clear(); if (ignore_eos) { - sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY; + sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY; } const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx)); @@ -424,9 +440,9 @@ Java_com_rnllama_LlamaContext_doCompletion( llama_token tok = static_cast(doubleArray[0]); if (tok >= 0 && tok < n_vocab) { if (doubleArray[1] != 0) { // If the second element is not false (0) - sparams.logit_bias[tok] = doubleArray[1]; + sparams.logit_bias[tok].bias = doubleArray[1]; } else { - sparams.logit_bias[tok] = -INFINITY; + sparams.logit_bias[tok].bias = -INFINITY; } } @@ -460,7 +476,7 @@ Java_com_rnllama_LlamaContext_doCompletion( if (token_with_probs.tok == -1 || llama->incomplete) { continue; } - const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok); + const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok); size_t pos = std::min(sent_count, llama->generated_text.size()); @@ -495,7 +511,7 @@ Java_com_rnllama_LlamaContext_doCompletion( putString(env, tokenResult, "token", to_send.c_str()); if (llama->params.sparams.n_probs > 0) { - const std::vector to_send_toks = llama_tokenize(llama->ctx, to_send, false); + const std::vector to_send_toks = common_tokenize(llama->ctx, to_send, false); size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size()); if (probs_pos < probs_stop_pos) { @@ -512,7 +528,7 @@ Java_com_rnllama_LlamaContext_doCompletion( } } - llama_print_timings(llama->ctx); + llama_perf_context_print(llama->ctx); llama->is_predicting = false; auto result = createWriteableMap(env); @@ -527,16 +543,17 @@ Java_com_rnllama_LlamaContext_doCompletion( putString(env, result, "stopping_word", llama->stopping_word.c_str()); putInt(env, result, "tokens_cached", llama->n_past); - const auto timings = llama_get_timings(llama->ctx); + const auto timings_token = llama_perf_context(llama -> ctx); + auto timingsResult = createWriteableMap(env); - putInt(env, timingsResult, "prompt_n", timings.n_p_eval); - putInt(env, timingsResult, "prompt_ms", timings.t_p_eval_ms); - putInt(env, timingsResult, "prompt_per_token_ms", timings.t_p_eval_ms / timings.n_p_eval); - putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings.t_p_eval_ms * timings.n_p_eval); - putInt(env, timingsResult, "predicted_n", timings.n_eval); - putInt(env, timingsResult, "predicted_ms", timings.t_eval_ms); - putInt(env, timingsResult, "predicted_per_token_ms", timings.t_eval_ms / timings.n_eval); - putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings.t_eval_ms * timings.n_eval); + putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval); + putInt(env, timingsResult, "prompt_ms", timings_token.t_p_eval_ms); + putInt(env, timingsResult, "prompt_per_token_ms", timings_token.t_p_eval_ms / timings_token.n_p_eval); + putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings_token.t_p_eval_ms * timings_token.n_p_eval); + putInt(env, timingsResult, "predicted_n", timings_token.n_eval); + putInt(env, timingsResult, "predicted_ms", timings_token.t_eval_ms); + putInt(env, timingsResult, "predicted_per_token_ms", timings_token.t_eval_ms / timings_token.n_eval); + putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings_token.t_eval_ms * timings_token.n_eval); putMap(env, result, "timings", timingsResult); @@ -569,7 +586,7 @@ Java_com_rnllama_LlamaContext_tokenize( const char *text_chars = env->GetStringUTFChars(text, nullptr); - const std::vector toks = llama_tokenize( + const std::vector toks = common_tokenize( llama->ctx, text_chars, false @@ -623,7 +640,7 @@ Java_com_rnllama_LlamaContext_embedding( llama->rewind(); - llama_reset_timings(llama->ctx); + llama_perf_context_reset(llama->ctx); llama->params.prompt = text_chars; @@ -681,9 +698,16 @@ Java_com_rnllama_LlamaContext_freeContext( } if (llama->ctx_sampling != nullptr) { - llama_sampling_free(llama->ctx_sampling); + common_sampler_free(llama->ctx_sampling); } context_map.erase((long) llama->ctx); } +JNIEXPORT void JNICALL +Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) { + UNUSED(env); + UNUSED(thiz); + llama_log_set(log_callback, NULL); +} + } // extern "C" diff --git a/cpp/common.cpp b/cpp/common.cpp index 94cb0f1..347dd0f 100644 --- a/cpp/common.cpp +++ b/cpp/common.cpp @@ -3,6 +3,7 @@ #endif #include "common.h" +#include "log.h" // Change JSON_ASSERT from assert() to LM_GGML_ASSERT: #define JSON_ASSERT LM_GGML_ASSERT #include "json.hpp" @@ -11,6 +12,7 @@ #include #include +#include #include #include #include @@ -22,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -48,7 +51,6 @@ #if defined(LLAMA_USE_CURL) #include #include -#include #include #endif @@ -62,14 +64,6 @@ char const *LLAMA_BUILD_TARGET = "unknown"; #pragma warning(disable: 4244 4267) // possible loss of data #endif -#if (defined(LM_GGML_USE_CUDA) || defined(LM_GGML_USE_SYCL)) -#define LM_GGML_USE_CUDA_SYCL -#endif - -#if (defined(LM_GGML_USE_CUDA) || defined(LM_GGML_USE_SYCL)) || defined(LM_GGML_USE_VULKAN) -#define LM_GGML_USE_CUDA_SYCL_VULKAN -#endif - #if defined(LLAMA_USE_CURL) #ifdef __linux__ #include @@ -83,41 +77,6 @@ char const *LLAMA_BUILD_TARGET = "unknown"; using json = nlohmann::ordered_json; -// -// Environment variable utils -// - -template -static typename std::enable_if::value, void>::type -get_env(std::string name, T & target) { - char * value = std::getenv(name.c_str()); - target = value ? std::string(value) : target; -} - -template -static typename std::enable_if::value && std::is_integral::value, void>::type -get_env(std::string name, T & target) { - char * value = std::getenv(name.c_str()); - target = value ? std::stoi(value) : target; -} - -template -static typename std::enable_if::value, void>::type -get_env(std::string name, T & target) { - char * value = std::getenv(name.c_str()); - target = value ? std::stof(value) : target; -} - -template -static typename std::enable_if::value, void>::type -get_env(std::string name, T & target) { - char * value = std::getenv(name.c_str()); - if (value) { - std::string val(value); - target = val == "1" || val == "true"; - } -} - // // CPU utils // @@ -211,1603 +170,227 @@ static bool is_hybrid_cpu(void) { static bool is_running_on_efficiency_core(void) { unsigned eax, ebx, ecx, edx; - cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx); - int intel_atom = 0x20; - int core_type = (eax & 0xff000000u) >> 24; - return core_type == intel_atom; -} - -static int cpu_count_math_cpus(int n_cpu) { - int result = 0; - for (int cpu = 0; cpu < n_cpu; ++cpu) { - if (pin_cpu(cpu)) { - return -1; - } - if (is_running_on_efficiency_core()) { - continue; // efficiency cores harm lockstep threading - } - ++cpu; // hyperthreading isn't useful for linear algebra - ++result; - } - return result; -} - -#endif // __x86_64__ && __linux__ - -/** - * Returns number of CPUs on system that are useful for math. - */ -int32_t cpu_get_num_math() { -#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) - int n_cpu = sysconf(_SC_NPROCESSORS_ONLN); - if (n_cpu < 1) { - return cpu_get_num_physical_cores(); - } - if (is_hybrid_cpu()) { - cpu_set_t affinity; - if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) { - int result = cpu_count_math_cpus(n_cpu); - pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity); - if (result > 0) { - return result; - } - } - } -#endif - return cpu_get_num_physical_cores(); -} - -// -// CLI argument parsing -// - -void gpt_params_handle_model_default(gpt_params & params) { - if (!params.hf_repo.empty()) { - // short-hand to avoid specifying --hf-file -> default it to --model - if (params.hf_file.empty()) { - if (params.model.empty()) { - throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); - } - params.hf_file = params.model; - } else if (params.model.empty()) { - params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); - } - } else if (!params.model_url.empty()) { - if (params.model.empty()) { - auto f = string_split(params.model_url, '#').front(); - f = string_split(f, '?').front(); - params.model = fs_get_cache_file(string_split(f, '/').back()); - } - } else if (params.model.empty()) { - params.model = DEFAULT_MODEL_PATH; - } -} - -bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { - bool invalid_param = false; - std::string arg; - const std::string arg_prefix = "--"; - llama_sampling_params & sparams = params.sparams; - - for (int i = 1; i < argc; i++) { - arg = argv[i]; - if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { - std::replace(arg.begin(), arg.end(), '_', '-'); - } - if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) { - throw std::invalid_argument("error: unknown argument: " + arg); - } - if (invalid_param) { - throw std::invalid_argument("error: invalid parameter for argument: " + arg); - } - } - - if (params.prompt_cache_all && (params.interactive || params.interactive_first)) { - throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); - } - - gpt_params_handle_model_default(params); - - if (params.hf_token.empty()) { - get_env("HF_TOKEN", params.hf_token); - } - - if (params.escape) { - string_process_escapes(params.prompt); - string_process_escapes(params.input_prefix); - string_process_escapes(params.input_suffix); - string_process_escapes(sparams.cfg_negative_prompt); - for (auto & antiprompt : params.antiprompt) { - string_process_escapes(antiprompt); - } - } - - if (!params.kv_overrides.empty()) { - params.kv_overrides.emplace_back(); - params.kv_overrides.back().key[0] = 0; - } - - return true; -} - -void gpt_params_parse_from_env(gpt_params & params) { - // we only care about server-related params for now - get_env("LLAMA_ARG_MODEL", params.model); - get_env("LLAMA_ARG_THREADS", params.n_threads); - get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx); - get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel); - get_env("LLAMA_ARG_BATCH", params.n_batch); - get_env("LLAMA_ARG_UBATCH", params.n_ubatch); - get_env("LLAMA_ARG_N_GPU_LAYERS", params.n_gpu_layers); - get_env("LLAMA_ARG_THREADS_HTTP", params.n_threads_http); - get_env("LLAMA_ARG_CHAT_TEMPLATE", params.chat_template); - get_env("LLAMA_ARG_N_PREDICT", params.n_predict); - get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics); - get_env("LLAMA_ARG_ENDPOINT_SLOTS", params.endpoint_slots); - get_env("LLAMA_ARG_EMBEDDINGS", params.embedding); - get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn); - get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold); -} - -bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { - const auto params_org = params; // the example can modify the default params - - try { - if (!gpt_params_parse_ex(argc, argv, params) || params.usage) { - params = params_org; - params.usage = true; - return false; - } - } catch (const std::invalid_argument & ex) { - fprintf(stderr, "%s\n", ex.what()); - params = params_org; - return false; - } - - return true; -} - -#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; } - -bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { - const char split_delim = ','; - - llama_sampling_params & sparams = params.sparams; - - if (arg == "-s" || arg == "--seed") { - CHECK_ARG - // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context. - params.seed = std::stoul(argv[i]); - sparams.seed = std::stoul(argv[i]); - return true; - } - if (arg == "-t" || arg == "--threads") { - CHECK_ARG - params.n_threads = std::stoi(argv[i]); - if (params.n_threads <= 0) { - params.n_threads = std::thread::hardware_concurrency(); - } - return true; - } - if (arg == "-tb" || arg == "--threads-batch") { - CHECK_ARG - params.n_threads_batch = std::stoi(argv[i]); - if (params.n_threads_batch <= 0) { - params.n_threads_batch = std::thread::hardware_concurrency(); - } - return true; - } - if (arg == "-td" || arg == "--threads-draft") { - CHECK_ARG - params.n_threads_draft = std::stoi(argv[i]); - if (params.n_threads_draft <= 0) { - params.n_threads_draft = std::thread::hardware_concurrency(); - } - return true; - } - if (arg == "-tbd" || arg == "--threads-batch-draft") { - CHECK_ARG - params.n_threads_batch_draft = std::stoi(argv[i]); - if (params.n_threads_batch_draft <= 0) { - params.n_threads_batch_draft = std::thread::hardware_concurrency(); - } - return true; - } - if (arg == "-p" || arg == "--prompt") { - CHECK_ARG - params.prompt = argv[i]; - return true; - } - if (arg == "-e" || arg == "--escape") { - params.escape = true; - return true; - } - if (arg == "--no-escape") { - params.escape = false; - return true; - } - if (arg == "--prompt-cache") { - CHECK_ARG - params.path_prompt_cache = argv[i]; - return true; - } - if (arg == "--prompt-cache-all") { - params.prompt_cache_all = true; - return true; - } - if (arg == "--prompt-cache-ro") { - params.prompt_cache_ro = true; - return true; - } - if (arg == "-bf" || arg == "--binary-file") { - CHECK_ARG - std::ifstream file(argv[i], std::ios::binary); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; - } - // store the external file name in params - params.prompt_file = argv[i]; - std::ostringstream ss; - ss << file.rdbuf(); - params.prompt = ss.str(); - fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), argv[i]); - return true; - } - if (arg == "-f" || arg == "--file") { - CHECK_ARG - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; - } - // store the external file name in params - params.prompt_file = argv[i]; - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); - if (!params.prompt.empty() && params.prompt.back() == '\n') { - params.prompt.pop_back(); - } - return true; - } - if (arg == "--in-file") { - CHECK_ARG - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; - } - params.in_files.push_back(argv[i]); - return true; - } - if (arg == "-n" || arg == "--predict" || arg == "--n-predict") { - CHECK_ARG - params.n_predict = std::stoi(argv[i]); - return true; - } - if (arg == "--top-k") { - CHECK_ARG - sparams.top_k = std::stoi(argv[i]); - return true; - } - if (arg == "-c" || arg == "--ctx-size") { - CHECK_ARG - params.n_ctx = std::stoi(argv[i]); - return true; - } - if (arg == "--grp-attn-n" || arg == "-gan") { - CHECK_ARG - params.grp_attn_n = std::stoi(argv[i]); - return true; - } - if (arg == "--grp-attn-w" || arg == "-gaw") { - CHECK_ARG - params.grp_attn_w = std::stoi(argv[i]); - return true; - } - if (arg == "--rope-freq-base") { - CHECK_ARG - params.rope_freq_base = std::stof(argv[i]); - return true; - } - if (arg == "--rope-freq-scale") { - CHECK_ARG - params.rope_freq_scale = std::stof(argv[i]); - return true; - } - if (arg == "--rope-scaling") { - CHECK_ARG - std::string value(argv[i]); - /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } - else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } - else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } - else { invalid_param = true; } - return true; - } - if (arg == "--rope-scale") { - CHECK_ARG - params.rope_freq_scale = 1.0f / std::stof(argv[i]); - return true; - } - if (arg == "--yarn-orig-ctx") { - CHECK_ARG - params.yarn_orig_ctx = std::stoi(argv[i]); - return true; - } - if (arg == "--yarn-ext-factor") { - CHECK_ARG - params.yarn_ext_factor = std::stof(argv[i]); - return true; - } - if (arg == "--yarn-attn-factor") { - CHECK_ARG - params.yarn_attn_factor = std::stof(argv[i]); - return true; - } - if (arg == "--yarn-beta-fast") { - CHECK_ARG - params.yarn_beta_fast = std::stof(argv[i]); - return true; - } - if (arg == "--yarn-beta-slow") { - CHECK_ARG - params.yarn_beta_slow = std::stof(argv[i]); - return true; - } - if (arg == "--pooling") { - CHECK_ARG - std::string value(argv[i]); - /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } - else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } - else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } - else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } - else { invalid_param = true; } - return true; - } - if (arg == "--attention") { - CHECK_ARG - std::string value(argv[i]); - /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; } - else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; } - else { invalid_param = true; } - return true; - } - if (arg == "--defrag-thold" || arg == "-dt") { - CHECK_ARG - params.defrag_thold = std::stof(argv[i]); - return true; - } - if (arg == "--samplers") { - CHECK_ARG - const auto sampler_names = string_split(argv[i], ';'); - sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true); - return true; - } - if (arg == "--sampling-seq") { - CHECK_ARG - sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]); - return true; - } - if (arg == "--top-p") { - CHECK_ARG - sparams.top_p = std::stof(argv[i]); - return true; - } - if (arg == "--min-p") { - CHECK_ARG - sparams.min_p = std::stof(argv[i]); - return true; - } - if (arg == "--temp") { - CHECK_ARG - sparams.temp = std::stof(argv[i]); - sparams.temp = std::max(sparams.temp, 0.0f); - return true; - } - if (arg == "--tfs") { - CHECK_ARG - sparams.tfs_z = std::stof(argv[i]); - return true; - } - if (arg == "--typical") { - CHECK_ARG - sparams.typical_p = std::stof(argv[i]); - return true; - } - if (arg == "--repeat-last-n") { - CHECK_ARG - sparams.penalty_last_n = std::stoi(argv[i]); - sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n); - return true; - } - if (arg == "--repeat-penalty") { - CHECK_ARG - sparams.penalty_repeat = std::stof(argv[i]); - return true; - } - if (arg == "--frequency-penalty") { - CHECK_ARG - sparams.penalty_freq = std::stof(argv[i]); - return true; - } - if (arg == "--presence-penalty") { - CHECK_ARG - sparams.penalty_present = std::stof(argv[i]); - return true; - } - if (arg == "--dynatemp-range") { - CHECK_ARG - sparams.dynatemp_range = std::stof(argv[i]); - return true; - } - if (arg == "--dynatemp-exp") { - CHECK_ARG - sparams.dynatemp_exponent = std::stof(argv[i]); - return true; - } - if (arg == "--mirostat") { - CHECK_ARG - sparams.mirostat = std::stoi(argv[i]); - return true; - } - if (arg == "--mirostat-lr") { - CHECK_ARG - sparams.mirostat_eta = std::stof(argv[i]); - return true; - } - if (arg == "--mirostat-ent") { - CHECK_ARG - sparams.mirostat_tau = std::stof(argv[i]); - return true; - } - if (arg == "--cfg-negative-prompt") { - CHECK_ARG - sparams.cfg_negative_prompt = argv[i]; - return true; - } - if (arg == "--cfg-negative-prompt-file") { - CHECK_ARG - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; - } - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(sparams.cfg_negative_prompt)); - if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { - sparams.cfg_negative_prompt.pop_back(); - } - return true; - } - if (arg == "--cfg-scale") { - CHECK_ARG - sparams.cfg_scale = std::stof(argv[i]); - return true; - } - if (arg == "-b" || arg == "--batch-size") { - CHECK_ARG - params.n_batch = std::stoi(argv[i]); - return true; - } - if (arg == "-ub" || arg == "--ubatch-size") { - CHECK_ARG - params.n_ubatch = std::stoi(argv[i]); - return true; - } - if (arg == "--keep") { - CHECK_ARG - params.n_keep = std::stoi(argv[i]); - return true; - } - if (arg == "--draft") { - CHECK_ARG - params.n_draft = std::stoi(argv[i]); - return true; - } - if (arg == "--chunks") { - CHECK_ARG - params.n_chunks = std::stoi(argv[i]); - return true; - } - if (arg == "-np" || arg == "--parallel") { - CHECK_ARG - params.n_parallel = std::stoi(argv[i]); - return true; - } - if (arg == "-ns" || arg == "--sequences") { - CHECK_ARG - params.n_sequences = std::stoi(argv[i]); - return true; - } - if (arg == "--p-split" || arg == "-ps") { - CHECK_ARG - params.p_split = std::stof(argv[i]); - return true; - } - if (arg == "-m" || arg == "--model") { - CHECK_ARG - params.model = argv[i]; - return true; - } - if (arg == "-md" || arg == "--model-draft") { - CHECK_ARG - params.model_draft = argv[i]; - return true; - } - if (arg == "-a" || arg == "--alias") { - CHECK_ARG - params.model_alias = argv[i]; - return true; - } - if (arg == "-mu" || arg == "--model-url") { - CHECK_ARG - params.model_url = argv[i]; - return true; - } - if (arg == "-hft" || arg == "--hf-token") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.hf_token = argv[i]; - return true; - } - if (arg == "-hfr" || arg == "--hf-repo") { - CHECK_ARG - params.hf_repo = argv[i]; - return true; - } - if (arg == "-hff" || arg == "--hf-file") { - CHECK_ARG - params.hf_file = argv[i]; - return true; - } - if (arg == "--lora") { - CHECK_ARG - params.lora_adapters.push_back({ - std::string(argv[i]), - 1.0, - }); - return true; - } - if (arg == "--lora-scaled") { - CHECK_ARG - std::string lora_adapter = argv[i]; - CHECK_ARG - params.lora_adapters.push_back({ - lora_adapter, - std::stof(argv[i]), - }); - return true; - } - if (arg == "--lora-init-without-apply") { - params.lora_init_without_apply = true; - return true; - } - if (arg == "--control-vector") { - CHECK_ARG - params.control_vectors.push_back({ 1.0f, argv[i], }); - return true; - } - if (arg == "--control-vector-scaled") { - CHECK_ARG - const char* fname = argv[i]; - CHECK_ARG - params.control_vectors.push_back({ std::stof(argv[i]), fname, }); - return true; - } - if (arg == "--control-vector-layer-range") { - CHECK_ARG - params.control_vector_layer_start = std::stoi(argv[i]); - CHECK_ARG - params.control_vector_layer_end = std::stoi(argv[i]); - return true; - } - if (arg == "--mmproj") { - CHECK_ARG - params.mmproj = argv[i]; - return true; - } - if (arg == "--image") { - CHECK_ARG - params.image.emplace_back(argv[i]); - return true; - } - if (arg == "-i" || arg == "--interactive") { - params.interactive = true; - return true; - } - if (arg == "-sp" || arg == "--special") { - params.special = true; - return true; - } - if (arg == "--embedding" || arg == "--embeddings") { - params.embedding = true; - return true; - } - if (arg == "--embd-normalize") { - CHECK_ARG - params.embd_normalize = std::stoi(argv[i]); - return true; - } - if (arg == "--embd-output-format") { - CHECK_ARG - params.embd_out = argv[i]; - return true; - } - if (arg == "--embd-separator") { - CHECK_ARG - params.embd_sep = argv[i]; - return true; - } - if (arg == "-if" || arg == "--interactive-first") { - params.interactive_first = true; - return true; - } - if (arg == "-cnv" || arg == "--conversation") { - params.conversation = true; - return true; - } - if (arg == "--infill") { - params.infill = true; - return true; - } - if (arg == "-dkvc" || arg == "--dump-kv-cache") { - params.dump_kv_cache = true; - return true; - } - if (arg == "-nkvo" || arg == "--no-kv-offload") { - params.no_kv_offload = true; - return true; - } - if (arg == "-ctk" || arg == "--cache-type-k") { - params.cache_type_k = argv[++i]; - return true; - } - if (arg == "-ctv" || arg == "--cache-type-v") { - params.cache_type_v = argv[++i]; - return true; - } - if (arg == "-mli" || arg == "--multiline-input") { - params.multiline_input = true; - return true; - } - if (arg == "--simple-io") { - params.simple_io = true; - return true; - } - if (arg == "-cb" || arg == "--cont-batching") { - params.cont_batching = true; - return true; - } - if (arg == "-nocb" || arg == "--no-cont-batching") { - params.cont_batching = false; - return true; - } - if (arg == "-fa" || arg == "--flash-attn") { - params.flash_attn = true; - return true; - } - if (arg == "-co" || arg == "--color") { - params.use_color = true; - return true; - } - if (arg == "--mlock") { - params.use_mlock = true; - return true; - } - if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") { - CHECK_ARG - params.n_gpu_layers = std::stoi(argv[i]); - if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); - } - return true; - } - if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--n-gpu-layers-draft") { - CHECK_ARG - params.n_gpu_layers_draft = std::stoi(argv[i]); - if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); - } - return true; - } - if (arg == "--main-gpu" || arg == "-mg") { - CHECK_ARG - params.main_gpu = std::stoi(argv[i]); -#ifndef LM_GGML_USE_CUDA_SYCL_VULKAN - fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the main GPU has no effect.\n"); -#endif // LM_GGML_USE_CUDA_SYCL_VULKAN - return true; - } - if (arg == "--split-mode" || arg == "-sm") { - CHECK_ARG - std::string arg_next = argv[i]; - if (arg_next == "none") { - params.split_mode = LLAMA_SPLIT_MODE_NONE; - } - else if (arg_next == "layer") { - params.split_mode = LLAMA_SPLIT_MODE_LAYER; - } - else if (arg_next == "row") { -#ifdef LM_GGML_USE_SYCL - fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n"); - exit(1); -#endif // LM_GGML_USE_SYCL - params.split_mode = LLAMA_SPLIT_MODE_ROW; - } - else { - invalid_param = true; - return true; - } -#ifndef LM_GGML_USE_CUDA_SYCL_VULKAN - fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the split mode has no effect.\n"); -#endif // LM_GGML_USE_CUDA_SYCL_VULKAN - return true; - } - if (arg == "--tensor-split" || arg == "-ts") { - CHECK_ARG - std::string arg_next = argv[i]; - - // split string by , and / - const std::regex regex{ R"([,/]+)" }; - std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 }; - std::vector split_arg{ it, {} }; - if (split_arg.size() >= llama_max_devices()) { - invalid_param = true; - return true; - } - for (size_t i = 0; i < llama_max_devices(); ++i) { - if (i < split_arg.size()) { - params.tensor_split[i] = std::stof(split_arg[i]); - } - else { - params.tensor_split[i] = 0.0f; - } - } -#ifndef LM_GGML_USE_CUDA_SYCL_VULKAN - fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting a tensor split has no effect.\n"); -#endif // LM_GGML_USE_CUDA_SYCL_VULKAN - return true; - } - if (arg == "--rpc") { - CHECK_ARG - params.rpc_servers = argv[i]; - return true; - } - if (arg == "--no-mmap") { - params.use_mmap = false; - return true; - } - if (arg == "--numa") { - CHECK_ARG - std::string value(argv[i]); - /**/ if (value == "distribute" || value == "") { params.numa = LM_GGML_NUMA_STRATEGY_DISTRIBUTE; } - else if (value == "isolate") { params.numa = LM_GGML_NUMA_STRATEGY_ISOLATE; } - else if (value == "numactl") { params.numa = LM_GGML_NUMA_STRATEGY_NUMACTL; } - else { invalid_param = true; } - return true; - } - if (arg == "-v" || arg == "--verbose") { - params.verbosity = 1; - return true; - } - if (arg == "--verbosity") { - CHECK_ARG - params.verbosity = std::stoi(argv[i]); - return true; - } - if (arg == "--verbose-prompt") { - params.verbose_prompt = true; - return true; - } - if (arg == "--no-display-prompt") { - params.display_prompt = false; - return true; - } - if (arg == "-r" || arg == "--reverse-prompt") { - CHECK_ARG - params.antiprompt.emplace_back(argv[i]); - return true; - } - if (arg == "-ld" || arg == "--logdir") { - CHECK_ARG - params.logdir = argv[i]; - - if (params.logdir.back() != DIRECTORY_SEPARATOR) { - params.logdir += DIRECTORY_SEPARATOR; - } - return true; - } - if (arg == "-lcs" || arg == "--lookup-cache-static") { - CHECK_ARG - params.lookup_cache_static = argv[i]; - return true; - } - if (arg == "-lcd" || arg == "--lookup-cache-dynamic") { - CHECK_ARG - params.lookup_cache_dynamic = argv[i]; - return true; - } - if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { - CHECK_ARG - params.logits_file = argv[i]; - return true; - } - if (arg == "--perplexity" || arg == "--all-logits") { - params.logits_all = true; - return true; - } - if (arg == "--ppl-stride") { - CHECK_ARG - params.ppl_stride = std::stoi(argv[i]); - return true; - } - if (arg == "--ppl-output-type") { - CHECK_ARG - params.ppl_output_type = std::stoi(argv[i]); - return true; - } - if (arg == "-ptc" || arg == "--print-token-count") { - CHECK_ARG - params.n_print = std::stoi(argv[i]); - return true; - } - if (arg == "--check-tensors") { - params.check_tensors = true; - return true; - } - if (arg == "--hellaswag") { - params.hellaswag = true; - return true; - } - if (arg == "--hellaswag-tasks") { - CHECK_ARG - params.hellaswag_tasks = std::stoi(argv[i]); - return true; - } - if (arg == "--winogrande") { - params.winogrande = true; - return true; - } - if (arg == "--winogrande-tasks") { - CHECK_ARG - params.winogrande_tasks = std::stoi(argv[i]); - return true; - } - if (arg == "--multiple-choice") { - params.multiple_choice = true; - return true; - } - if (arg == "--multiple-choice-tasks") { - CHECK_ARG - params.multiple_choice_tasks = std::stoi(argv[i]); - return true; - } - if (arg == "--kl-divergence") { - params.kl_divergence = true; - return true; - } - if (arg == "--ignore-eos") { - params.ignore_eos = true; - return true; - } - if (arg == "--penalize-nl") { - sparams.penalize_nl = true; - return true; - } - if (arg == "-l" || arg == "--logit-bias") { - CHECK_ARG - std::stringstream ss(argv[i]); - llama_token key; - char sign; - std::string value_str; - try { - if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { - sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); - } - else { - throw std::exception(); - } - } - catch (const std::exception&) { - invalid_param = true; - return true; - } - return true; - } - if (arg == "-h" || arg == "--help" || arg == "--usage" ) { - params.usage = true; - return true; - } - if (arg == "--version") { - fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); - fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); - exit(0); - } - if (arg == "--in-prefix-bos") { - params.input_prefix_bos = true; - params.enable_chat_template = false; - return true; - } - if (arg == "--in-prefix") { - CHECK_ARG - params.input_prefix = argv[i]; - params.enable_chat_template = false; - return true; - } - if (arg == "--in-suffix") { - CHECK_ARG - params.input_suffix = argv[i]; - params.enable_chat_template = false; - return true; - } - if (arg == "--spm-infill") { - params.spm_infill = true; - return true; - } - if (arg == "--grammar") { - CHECK_ARG - sparams.grammar = argv[i]; - return true; - } - if (arg == "--grammar-file") { - CHECK_ARG - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; + cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx); + int intel_atom = 0x20; + int core_type = (eax & 0xff000000u) >> 24; + return core_type == intel_atom; +} + +static int cpu_count_math_cpus(int n_cpu) { + int result = 0; + for (int cpu = 0; cpu < n_cpu; ++cpu) { + if (pin_cpu(cpu)) { + return -1; } - std::copy( - std::istreambuf_iterator(file), - std::istreambuf_iterator(), - std::back_inserter(sparams.grammar) - ); - return true; - } - if (arg == "-j" || arg == "--json-schema") { - CHECK_ARG - sparams.grammar = json_schema_to_grammar(json::parse(argv[i])); - return true; - } - if (arg == "--override-kv") { - CHECK_ARG - if (!string_parse_kv_override(argv[i], params.kv_overrides)) { - fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); - invalid_param = true; - return true; + if (is_running_on_efficiency_core()) { + continue; // efficiency cores harm lockstep threading } - return true; - } - if (arg == "--host") { - CHECK_ARG - params.hostname = argv[i]; - return true; - } - if (arg == "--port") { - CHECK_ARG - params.port = std::stoi(argv[i]); - return true; - } - if (arg == "--path") { - CHECK_ARG - params.public_path = argv[i]; - return true; + ++cpu; // hyperthreading isn't useful for linear algebra + ++result; } - if (arg == "--api-key") { - CHECK_ARG - params.api_keys.push_back(argv[i]); - return true; + return result; +} + +#endif // __x86_64__ && __linux__ + +/** + * Returns number of CPUs on system that are useful for math. + */ +int32_t cpu_get_num_math() { +#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) + int n_cpu = sysconf(_SC_NPROCESSORS_ONLN); + if (n_cpu < 1) { + return cpu_get_num_physical_cores(); } - if (arg == "--api-key-file") { - CHECK_ARG - std::ifstream key_file(argv[i]); - if (!key_file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; - } - std::string key; - while (std::getline(key_file, key)) { - if (!key.empty()) { - params.api_keys.push_back(key); + if (is_hybrid_cpu()) { + cpu_set_t affinity; + if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) { + int result = cpu_count_math_cpus(n_cpu); + pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity); + if (result > 0) { + return result; } } - key_file.close(); - return true; - } - if (arg == "--ssl-key-file") { - CHECK_ARG - params.ssl_file_key = argv[i]; - return true; - } - if (arg == "--ssl-cert-file") { - CHECK_ARG - params.ssl_file_cert = argv[i]; - return true; } - if (arg == "--timeout" || arg == "-to") { - CHECK_ARG - params.timeout_read = std::stoi(argv[i]); - params.timeout_write = std::stoi(argv[i]); +#endif + return cpu_get_num_physical_cores(); +} + +// Helper for setting process priority + +#if defined(_WIN32) + +bool set_process_priority(enum lm_ggml_sched_priority prio) { + if (prio == LM_GGML_SCHED_PRIO_NORMAL) { return true; } - if (arg == "--threads-http") { - CHECK_ARG - params.n_threads_http = std::stoi(argv[i]); - return true; + + DWORD p = NORMAL_PRIORITY_CLASS; + switch (prio) { + case LM_GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break; + case LM_GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break; + case LM_GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break; + case LM_GGML_SCHED_PRIO_REALTIME: p = REALTIME_PRIORITY_CLASS; break; } - if (arg == "-spf" || arg == "--system-prompt-file") { - CHECK_ARG - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; - } - std::string system_prompt; - std::copy( - std::istreambuf_iterator(file), - std::istreambuf_iterator(), - std::back_inserter(system_prompt) - ); - params.system_prompt = system_prompt; - return true; + + if (!SetPriorityClass(GetCurrentProcess(), p)) { + LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError()); + return false; } - if (arg == "--log-format") { - CHECK_ARG - if (std::strcmp(argv[i], "json") == 0) { - params.log_json = true; - } else if (std::strcmp(argv[i], "text") == 0) { - params.log_json = false; - } else { - invalid_param = true; - return true; - } + + return true; +} + +#else // MacOS and POSIX +#include +#include + +bool set_process_priority(enum lm_ggml_sched_priority prio) { + if (prio == LM_GGML_SCHED_PRIO_NORMAL) { return true; } - if (arg == "--no-slots") { - params.endpoint_slots = false; - return true; + + int p = 0; + switch (prio) { + case LM_GGML_SCHED_PRIO_NORMAL: p = 0; break; + case LM_GGML_SCHED_PRIO_MEDIUM: p = -5; break; + case LM_GGML_SCHED_PRIO_HIGH: p = -10; break; + case LM_GGML_SCHED_PRIO_REALTIME: p = -20; break; } - if (arg == "--metrics") { - params.endpoint_metrics = true; - return true; + + if (!setpriority(PRIO_PROCESS, 0, p)) { + LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno); + return false; } - if (arg == "--slot-save-path") { - CHECK_ARG - params.slot_save_path = argv[i]; - // if doesn't end with DIRECTORY_SEPARATOR, add it - if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { - params.slot_save_path += DIRECTORY_SEPARATOR; + return true; +} + +#endif + +// +// CLI argument parsing +// + + +void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) { + int32_t n_set = 0; + + if (cpuparams.n_threads < 0) { + // Assuming everything about cpuparams is invalid + if (role_model != nullptr) { + cpuparams = *role_model; + } else { + cpuparams.n_threads = cpu_get_num_math(); } - return true; } - if (arg == "--chat-template") { - CHECK_ARG - if (!llama_chat_verify_template(argv[i])) { - fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]); - fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n"); - invalid_param = true; - return true; + + for (int32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) { + if (cpuparams.cpumask[i]) { + n_set++; } - params.chat_template = argv[i]; - return true; - } - if (arg == "--slot-prompt-similarity" || arg == "-sps") { - CHECK_ARG - params.slot_prompt_similarity = std::stof(argv[i]); - return true; - } - if (arg == "-pps") { - params.is_pp_shared = true; - return true; - } - if (arg == "-npp") { - CHECK_ARG - auto p = string_split(argv[i], split_delim); - params.n_pp.insert(params.n_pp.end(), p.begin(), p.end()); - return true; } - if (arg == "-ntg") { - CHECK_ARG - auto p = string_split(argv[i], split_delim); - params.n_tg.insert(params.n_tg.end(), p.begin(), p.end()); - return true; + + if (n_set && n_set < cpuparams.n_threads) { + // Not enough set bits, may experience performance issues. + LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads); } - if (arg == "-npl") { - CHECK_ARG - auto p = string_split(argv[i], split_delim); - params.n_pl.insert(params.n_pl.end(), p.begin(), p.end()); - return true; +} + +bool parse_cpu_range(const std::string & range, bool (&boolmask)[LM_GGML_MAX_N_THREADS]) { + size_t dash_loc = range.find('-'); + if (dash_loc == std::string::npos) { + LOG_ERR("Format of CPU range is invalid! Expected []-[].\n"); + return false; } - if (arg == "--context-file") { - CHECK_ARG - std::ifstream file(argv[i], std::ios::binary); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; + + size_t start_i; + size_t end_i; + + if (dash_loc == 0) { + start_i = 0; + } else { + start_i = std::stoull(range.substr(0, dash_loc)); + if (start_i >= LM_GGML_MAX_N_THREADS) { + LOG_ERR("Start index out of bounds!\n"); + return false; } - params.context_files.push_back(argv[i]); - return true; - } - if (arg == "--chunk-size") { - CHECK_ARG - params.chunk_size = std::stoi(argv[i]); - return true; - } - if (arg == "--chunk-separator") { - CHECK_ARG - params.chunk_separator = argv[i]; - return true; - } - if (arg == "--junk") { - CHECK_ARG - params.n_junk = std::stoi(argv[i]); - return true; - } - if (arg == "--pos") { - CHECK_ARG - params.i_pos = std::stoi(argv[i]); - return true; - } - if (arg == "-o" || arg == "--output" || arg == "--output-file") { - CHECK_ARG - params.out_file = argv[i]; - params.cvector_outfile = argv[i]; - params.lora_outfile = argv[i]; - return true; - } - if (arg == "-ofreq" || arg == "--output-frequency") { - CHECK_ARG - params.n_out_freq = std::stoi(argv[i]); - return true; - } - if (arg == "--save-frequency") { - CHECK_ARG - params.n_save_freq = std::stoi(argv[i]); - return true; - } - if (arg == "--process-output") { - params.process_output = true; - return true; - } - if (arg == "--no-ppl") { - params.compute_ppl = false; - return true; - } - if (arg == "--chunk" || arg == "--from-chunk") { - CHECK_ARG - params.i_chunk = std::stoi(argv[i]); - return true; - } - // cvector params - if (arg == "--positive-file") { - CHECK_ARG - params.cvector_positive_file = argv[i]; - return true; - } - if (arg == "--negative-file") { - CHECK_ARG - params.cvector_negative_file = argv[i]; - return true; - } - if (arg == "--pca-batch") { - CHECK_ARG - params.n_pca_batch = std::stoi(argv[i]); - return true; - } - if (arg == "--pca-iter") { - CHECK_ARG - params.n_pca_iterations = std::stoi(argv[i]); - return true; - } - if (arg == "--method") { - CHECK_ARG - std::string value(argv[i]); - /**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; } - else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; } - else { invalid_param = true; } - return true; - } - if (arg == "--no-warmup") { - params.warmup = false; - return true; - } -#ifndef LOG_DISABLE_LOGS - // Parse args for logging parameters - if (log_param_single_parse(argv[i])) { - // Do nothing, log_param_single_parse automatically does it's thing - // and returns if a match was found and parsed. - return true; } - if (log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i])) { - // We have a matching known parameter requiring an argument, - // now we need to check if there is anything after this argv - // and flag invalid_param or parse it. - CHECK_ARG - if (!log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i - 1], argv[i])) { - invalid_param = true; - return true; + + if (dash_loc == range.length() - 1) { + end_i = LM_GGML_MAX_N_THREADS - 1; + } else { + end_i = std::stoull(range.substr(dash_loc + 1)); + if (end_i >= LM_GGML_MAX_N_THREADS) { + LOG_ERR("End index out of bounds!\n"); + return false; } - return true; } - // End of Parse args for logging parameters -#endif // LOG_DISABLE_LOGS - return false; + for (size_t i = start_i; i <= end_i; i++) { + boolmask[i] = true; + } + + return true; } -#ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) -#else -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) -#endif +bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[LM_GGML_MAX_N_THREADS]) { + // Discard potential 0x prefix + size_t start_i = 0; + if (mask.length() >= 2 && mask.substr(0, 2) == "0x") { + start_i = 2; + } -void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { - const llama_sampling_params & sparams = params.sparams; + size_t num_digits = mask.length() - start_i; + if (num_digits > 128) num_digits = 128; - std::string sampler_type_chars; - std::string sampler_type_names; - for (const auto sampler_type : sparams.samplers_sequence) { - sampler_type_chars += static_cast(sampler_type); - sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";"; - } - sampler_type_names.pop_back(); + size_t end_i = num_digits + start_i; - struct option_info { - LLAMA_COMMON_ATTRIBUTE_FORMAT(4, 5) - option_info(const std::string & tags, const char * args, const char * desc, ...) : tags(tags), args(args), desc(desc) { - va_list args_list; - va_start(args_list, desc); - char buffer[1024]; - vsnprintf(buffer, sizeof(buffer), desc, args_list); - va_end(args_list); - this->desc = buffer; + for (size_t i = start_i, n = (num_digits*4 - 1); i < end_i; i++, n-=4) { + char c = mask.at(i); + int8_t id = c; + + if ((c >= '0' && c <= '9')) { + id -= '0'; + } else if (c >= 'a' && c <= 'f') { + id -= 'a' - 10; + } else if (c >= 'A' && c <= 'F') { + id -= 'A' - 10; + } else { + LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i)); + return false; } - option_info(const std::string & grp) : grp(grp) {} + boolmask[ n ] = boolmask[ n ] || ((id & 8) != 0); + boolmask[n - 1] = boolmask[n - 1] || ((id & 4) != 0); + boolmask[n - 2] = boolmask[n - 2] || ((id & 2) != 0); + boolmask[n - 3] = boolmask[n - 3] || ((id & 1) != 0); + } - std::string tags; - std::string args; - std::string desc; - std::string grp; - }; + return true; +} + +void common_init() { + llama_log_set([](lm_ggml_log_level level, const char * text, void * /*user_data*/) { + if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) { + common_log_add(common_log_main(), level, "%s", text); + } + }, NULL); + +#ifdef NDEBUG + const char * build_type = ""; +#else + const char * build_type = " (debug)"; +#endif - std::vector options; - - // TODO: filter by tags - - options.push_back({ "general" }); - options.push_back({ "*", "-h, --help, --usage", "print usage and exit" }); - options.push_back({ "*", " --version", "show version and build info" }); - options.push_back({ "*", "-v, --verbose", "print verbose information" }); - options.push_back({ "*", " --verbosity N", "set specific verbosity level (default: %d)", params.verbosity }); - options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" }); - options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" }); - options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" }); - options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed }); - options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.n_threads }); - options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" }); - options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" }); - options.push_back({ "speculative", "-tbd, --threads-batch-draft N", - "number of threads to use during batch and prompt processing (default: same as --threads-draft)" }); - options.push_back({ "speculative", " --draft N", "number of tokens to draft for speculative decoding (default: %d)", params.n_draft }); - options.push_back({ "speculative", "-ps, --p-split N", "speculative decoding split probability (default: %.1f)", (double)params.p_split }); - options.push_back({ "*", "-lcs, --lookup-cache-static FNAME", - "path to static lookup cache to use for lookup decoding (not updated by generation)" }); - options.push_back({ "*", "-lcd, --lookup-cache-dynamic FNAME", - "path to dynamic lookup cache to use for lookup decoding (updated by generation)" }); - - options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx }); - options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict }); - options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch }); - options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch }); - options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); - options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); - options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); - options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" - "in conversation mode, this will be used as system prompt\n" - "(default: '%s')", params.prompt.c_str() }); - options.push_back({ "*", "-f, --file FNAME", "a file containing the prompt (default: none)" }); - options.push_back({ "*", " --in-file FNAME", "an input file (repeat to specify multiple files)" }); - options.push_back({ "*", "-bf, --binary-file FNAME", "binary file containing the prompt (default: none)" }); - options.push_back({ "*", "-e, --escape", "process escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\) (default: %s)", params.escape ? "true" : "false" }); - options.push_back({ "*", " --no-escape", "do not process escape sequences" }); - options.push_back({ "main", "-ptc, --print-token-count N", "print token count every N tokens (default: %d)", params.n_print }); - options.push_back({ "main", " --prompt-cache FNAME", "file to cache prompt state for faster startup (default: none)" }); - options.push_back({ "main", " --prompt-cache-all", "if specified, saves user input and generations to cache as well\n" - "not supported with --interactive or other interactive options" }); - options.push_back({ "main", " --prompt-cache-ro", "if specified, uses the prompt cache but does not update it" }); - options.push_back({ "main", "-r, --reverse-prompt PROMPT", - "halt generation at PROMPT, return control in interactive mode\n" - "can be specified more than once for multiple prompts" }); - options.push_back({ "main", "-sp, --special", "special tokens output enabled (default: %s)", params.special ? "true" : "false" }); - options.push_back({ "main", "-cnv, --conversation", "run in conversation mode, does not print special tokens and suffix/prefix\n" - "if suffix/prefix are not specified, default chat template will be used\n" - "(default: %s)", params.conversation ? "true" : "false" }); - options.push_back({ "main infill", "-i, --interactive", "run in interactive mode (default: %s)", params.interactive ? "true" : "false" }); - options.push_back({ "main infill", "-if, --interactive-first", "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" }); - options.push_back({ "main infill", "-mli, --multiline-input", "allows you to write or paste multiple lines without ending each in '\\'" }); - options.push_back({ "main infill", " --in-prefix-bos", "prefix BOS to user inputs, preceding the `--in-prefix` string" }); - options.push_back({ "main infill", " --in-prefix STRING", "string to prefix user inputs with (default: empty)" }); - options.push_back({ "main infill", " --in-suffix STRING", "string to suffix after user inputs with (default: empty)" }); - options.push_back({ "main", " --no-warmup", "skip warming up the model with an empty run" }); - options.push_back({ "server infill", - " --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" }); - - options.push_back({ "sampling" }); - options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n" - "(default: %s)", sampler_type_names.c_str() }); - options.push_back({ "*", " --sampling-seq SEQUENCE", - "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() }); - options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" }); - options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" }); - options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp }); - options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k }); - options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p }); - options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p }); - options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z }); - options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p }); - options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n }); - options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat }); - options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present }); - options.push_back({ "*", " --frequency-penalty N", "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq }); - options.push_back({ "*", " --dynatemp-range N", "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range }); - options.push_back({ "*", " --dynatemp-exp N", "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent }); - options.push_back({ "*", " --mirostat N", "use Mirostat sampling.\n" - "Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n" - "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", sparams.mirostat }); - options.push_back({ "*", " --mirostat-lr N", "Mirostat learning rate, parameter eta (default: %.1f)", (double)sparams.mirostat_eta }); - options.push_back({ "*", " --mirostat-ent N", "Mirostat target entropy, parameter tau (default: %.1f)", (double)sparams.mirostat_tau }); - options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n" - "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" - "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" }); - options.push_back({ "main", " --cfg-negative-prompt PROMPT", - "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() }); - options.push_back({ "main", " --cfg-negative-prompt-file FNAME", - "negative prompt file to use for guidance" }); - options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale }); - options.push_back({ "main", " --chat-template JINJA_TEMPLATE", - "set custom jinja chat template (default: template taken from model's metadata)\n" - "if suffix/prefix are specified, template will be disabled\n" - "only commonly used templates are accepted:\n" - "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); - options.push_back({ "grammar" }); - options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() }); - options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" }); - options.push_back({ "*", "-j, --json-schema SCHEMA", - "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\n" - "For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" }); - - options.push_back({ "embedding" }); - options.push_back({ "embedding", " --pooling {none,mean,cls,last}", - "pooling type for embeddings, use model default if unspecified" }); - options.push_back({ "embedding", " --attention {causal,non-causal}", - "attention type for embeddings, use model default if unspecified" }); - - options.push_back({ "context hacking" }); - options.push_back({ "*", " --rope-scaling {none,linear,yarn}", - "RoPE frequency scaling method, defaults to linear unless specified by the model" }); - options.push_back({ "*", " --rope-scale N", "RoPE context scaling factor, expands context by a factor of N" }); - options.push_back({ "*", " --rope-freq-base N", "RoPE base frequency, used by NTK-aware scaling (default: loaded from model)" }); - options.push_back({ "*", " --rope-freq-scale N", "RoPE frequency scaling factor, expands context by a factor of 1/N" }); - options.push_back({ "*", " --yarn-orig-ctx N", "YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx }); - options.push_back({ "*", " --yarn-ext-factor N", "YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor }); - options.push_back({ "*", " --yarn-attn-factor N", "YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor }); - options.push_back({ "*", " --yarn-beta-slow N", "YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow }); - options.push_back({ "*", " --yarn-beta-fast N", "YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast }); - options.push_back({ "*", "-gan, --grp-attn-n N", "group-attention factor (default: %d)", params.grp_attn_n }); - options.push_back({ "*", "-gaw, --grp-attn-w N", "group-attention width (default: %.1f)", (double)params.grp_attn_w }); - options.push_back({ "*", "-dkvc, --dump-kv-cache", "verbose print of the KV cache" }); - options.push_back({ "*", "-nkvo, --no-kv-offload", "disable KV offload" }); - options.push_back({ "*", "-ctk, --cache-type-k TYPE", "KV cache data type for K (default: %s)", params.cache_type_k.c_str() }); - options.push_back({ "*", "-ctv, --cache-type-v TYPE", "KV cache data type for V (default: %s)", params.cache_type_v.c_str() }); - - options.push_back({ "perplexity" }); - options.push_back({ "perplexity", " --all-logits", "return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false" }); - options.push_back({ "perplexity", " --hellaswag", "compute HellaSwag score over random tasks from datafile supplied with -f" }); - options.push_back({ "perplexity", " --hellaswag-tasks N", "number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks }); - options.push_back({ "perplexity", " --winogrande", "compute Winogrande score over random tasks from datafile supplied with -f" }); - options.push_back({ "perplexity", " --winogrande-tasks N", "number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks }); - options.push_back({ "perplexity", " --multiple-choice", "compute multiple choice score over random tasks from datafile supplied with -f" }); - options.push_back({ "perplexity", " --multiple-choice-tasks N", - "number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks }); - options.push_back({ "perplexity", " --kl-divergence", "computes KL-divergence to logits provided via --kl-divergence-base" }); - options.push_back({ "perplexity", " --ppl-stride N", "stride for perplexity calculation (default: %d)", params.ppl_stride }); - options.push_back({ "perplexity", " --ppl-output-type {0,1}", - "output type for perplexity calculation (default: %d)", params.ppl_output_type }); - - options.push_back({ "parallel" }); - options.push_back({ "*", "-dt, --defrag-thold N", "KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold }); - options.push_back({ "*", "-np, --parallel N", "number of parallel sequences to decode (default: %d)", params.n_parallel }); - options.push_back({ "*", "-ns, --sequences N", "number of sequences to decode (default: %d)", params.n_sequences }); - options.push_back({ "*", "-cb, --cont-batching", "enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled" }); - options.push_back({ "*", "-nocb, --no-cont-batching", "disable continuous batching" }); - - options.push_back({ "multi-modality" }); - options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" }); - options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" }); - - options.push_back({ "backend" }); - options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" }); - - if (llama_supports_mlock()) { - options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" }); - } - if (llama_supports_mmap()) { - options.push_back({ "*", " --no-mmap", "do not memory-map model (slower load but may reduce pageouts if not using mlock)" }); - } - options.push_back({ "*", " --numa TYPE", "attempt optimizations that help on some NUMA systems\n" - " - distribute: spread execution evenly over all nodes\n" - " - isolate: only spawn threads on CPUs on the node that execution started on\n" - " - numactl: use the CPU map provided by numactl\n" - "if run without this previously, it is recommended to drop the system page cache before using this\n" - "see https://github.com/ggerganov/llama.cpp/issues/1437" }); - - if (llama_supports_gpu_offload()) { - options.push_back({ "*", "-ngl, --gpu-layers N", - "number of layers to store in VRAM" }); - options.push_back({ "*", "-ngld, --gpu-layers-draft N", - "number of layers to store in VRAM for the draft model" }); - options.push_back({ "*", "-sm, --split-mode SPLIT_MODE", - "how to split the model across multiple GPUs, one of:\n" - " - none: use one GPU only\n" - " - layer (default): split layers and KV across GPUs\n" - " - row: split rows across GPUs" }); - options.push_back({ "*", "-ts, --tensor-split SPLIT", - "fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1" }); - options.push_back({ "*", "-mg, --main-gpu i", "the GPU to use for the model (with split-mode = none),\n" - "or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu }); - } - - options.push_back({ "model" }); - options.push_back({ "*", " --check-tensors", "check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false" }); - options.push_back({ "*", " --override-kv KEY=TYPE:VALUE", - "advanced option to override model metadata by key. may be specified multiple times.\n" - "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false" }); - options.push_back({ "*", " --lora FNAME", "apply LoRA adapter (can be repeated to use multiple adapters)" }); - options.push_back({ "*", " --lora-scaled FNAME S", "apply LoRA adapter with user defined scaling S (can be repeated to use multiple adapters)" }); - options.push_back({ "*", " --control-vector FNAME", "add a control vector\n" - "note: this argument can be repeated to add multiple control vectors" }); - options.push_back({ "*", " --control-vector-scaled FNAME SCALE", - "add a control vector with user defined scaling SCALE\n" - "note: this argument can be repeated to add multiple scaled control vectors" }); - options.push_back({ "*", " --control-vector-layer-range START END", - "layer range to apply the control vector(s) to, start and end inclusive" }); - options.push_back({ "*", "-m, --model FNAME", "model path (default: models/$filename with filename from --hf-file\n" - "or --model-url if set, otherwise %s)", DEFAULT_MODEL_PATH }); - options.push_back({ "*", "-md, --model-draft FNAME", "draft model for speculative decoding (default: unused)" }); - options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" }); - options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" }); - options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" }); - options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" }); - - options.push_back({ "retrieval" }); - options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" }); - options.push_back({ "retrieval", " --chunk-size N", "minimum length of embedded text chunks (default: %d)", params.chunk_size }); - options.push_back({ "retrieval", " --chunk-separator STRING", - "separator between chunks (default: '%s')", params.chunk_separator.c_str() }); - - options.push_back({ "passkey" }); - options.push_back({ "passkey", " --junk N", "number of times to repeat the junk text (default: %d)", params.n_junk }); - options.push_back({ "passkey", " --pos N", "position of the passkey in the junk text (default: %d)", params.i_pos }); - - options.push_back({ "imatrix" }); - options.push_back({ "imatrix", "-o, --output FNAME", "output file (default: '%s')", params.out_file.c_str() }); - options.push_back({ "imatrix", " --output-frequency N", "output the imatrix every N iterations (default: %d)", params.n_out_freq }); - options.push_back({ "imatrix", " --save-frequency N", "save an imatrix copy every N iterations (default: %d)", params.n_save_freq }); - options.push_back({ "imatrix", " --process-output", "collect data for the output tensor (default: %s)", params.process_output ? "true" : "false" }); - options.push_back({ "imatrix", " --no-ppl", "do not compute perplexity (default: %s)", params.compute_ppl ? "true" : "false" }); - options.push_back({ "imatrix", " --chunk N", "start processing the input from chunk N (default: %d)", params.i_chunk }); - - options.push_back({ "bench" }); - options.push_back({ "bench", "-pps", "is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false" }); - options.push_back({ "bench", "-npp n0,n1,...", "number of prompt tokens" }); - options.push_back({ "bench", "-ntg n0,n1,...", "number of text generation tokens" }); - options.push_back({ "bench", "-npl n0,n1,...", "number of parallel prompts" }); - - options.push_back({ "embedding" }); - options.push_back({ "embedding", " --embd-normalize", "normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize }); - options.push_back({ "embedding", " --embd-output-format", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix" }); - options.push_back({ "embedding", " --embd-separator", "separator of embendings (default \\n) for example \"<#sep#>\"" }); - - options.push_back({ "server" }); - options.push_back({ "server", " --host HOST", "ip address to listen (default: %s)", params.hostname.c_str() }); - options.push_back({ "server", " --port PORT", "port to listen (default: %d)", params.port }); - options.push_back({ "server", " --path PATH", "path to serve static files from (default: %s)", params.public_path.c_str() }); - options.push_back({ "server", " --embedding(s)", "restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled" }); - options.push_back({ "server", " --api-key KEY", "API key to use for authentication (default: none)" }); - options.push_back({ "server", " --api-key-file FNAME", "path to file containing API keys (default: none)" }); - options.push_back({ "server", " --ssl-key-file FNAME", "path to file a PEM-encoded SSL private key" }); - options.push_back({ "server", " --ssl-cert-file FNAME", "path to file a PEM-encoded SSL certificate" }); - options.push_back({ "server", " --timeout N", "server read/write timeout in seconds (default: %d)", params.timeout_read }); - options.push_back({ "server", " --threads-http N", "number of threads used to process HTTP requests (default: %d)", params.n_threads_http }); - options.push_back({ "server", " --system-prompt-file FNAME", - "set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications" }); - options.push_back({ "server", " --log-format {text,json}", - "log output format: json or text (default: json)" }); - options.push_back({ "server", " --metrics", "enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled" }); - options.push_back({ "server", " --no-slots", "disables slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled" }); - options.push_back({ "server", " --slot-save-path PATH", "path to save slot kv cache (default: disabled)" }); - options.push_back({ "server", " --chat-template JINJA_TEMPLATE", - "set custom jinja chat template (default: template taken from model's metadata)\n" - "only commonly used templates are accepted:\n" - "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); - options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY", - "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity }); - options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"}); - -#ifndef LOG_DISABLE_LOGS - options.push_back({ "logging" }); - options.push_back({ "*", " --simple-io", "use basic IO for better compatibility in subprocesses and limited consoles" }); - options.push_back({ "*", "-ld, --logdir LOGDIR", "path under which to save YAML logs (no logging if unset)" }); - options.push_back({ "logging", " --log-test", "Run simple logging test" }); - options.push_back({ "logging", " --log-disable", "Disable trace logs" }); - options.push_back({ "logging", " --log-enable", "Enable trace logs" }); - options.push_back({ "logging", " --log-file FNAME", "Specify a log filename (without extension)" }); - options.push_back({ "logging", " --log-new", "Create a separate new log file on start. " - "Each log file will have unique name: \"..log\"" }); - options.push_back({ "logging", " --log-append", "Don't truncate the old log file." }); -#endif // LOG_DISABLE_LOGS - - options.push_back({ "cvector" }); - options.push_back({ "cvector", "-o, --output FNAME", "output file (default: '%s')", params.cvector_outfile.c_str() }); - options.push_back({ "cvector", " --positive-file FNAME", "positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str() }); - options.push_back({ "cvector", " --negative-file FNAME", "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() }); - options.push_back({ "cvector", " --pca-batch N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch }); - options.push_back({ "cvector", " --pca-iter N", "number of iterations used for PCA (default: %d)", params.n_pca_iterations }); - options.push_back({ "cvector", " --method {pca,mean}", "dimensionality reduction method to be used (default: pca)" }); - - options.push_back({ "export-lora" }); - options.push_back({ "export-lora", "-m, --model", "model path from which to load base model (default '%s')", params.model.c_str() }); - options.push_back({ "export-lora", " --lora FNAME", "path to LoRA adapter (can be repeated to use multiple adapters)" }); - options.push_back({ "export-lora", " --lora-scaled FNAME S", "path to LoRA adapter with user defined scaling S (can be repeated to use multiple adapters)" }); - options.push_back({ "*", "-t, --threads N", "number of threads to use during computation (default: %d)", params.n_threads }); - options.push_back({ "export-lora", "-o, --output FNAME", "output file (default: '%s')", params.lora_outfile.c_str() }); - - printf("usage: %s [options]\n", argv[0]); - - for (const auto & o : options) { - if (!o.grp.empty()) { - printf("\n%s:\n\n", o.grp.c_str()); - continue; - } - printf(" %-32s", o.args.c_str()); - if (o.args.length() > 30) { - printf("\n%34s", ""); - } - - const auto desc = o.desc; - size_t start = 0; - size_t end = desc.find('\n'); - while (end != std::string::npos) { - printf("%s\n%34s", desc.substr(start, end - start).c_str(), ""); - start = end + 1; - end = desc.find('\n', start); - } - - printf("%s\n", desc.substr(start).c_str()); - } - printf("\n"); + LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type); } -std::string gpt_params_get_system_info(const gpt_params & params) { +std::string common_params_get_system_info(const common_params & params) { std::ostringstream os; - os << "system_info: n_threads = " << params.n_threads; - if (params.n_threads_batch != -1) { - os << " (n_threads_batch = " << params.n_threads_batch << ")"; + os << "system_info: n_threads = " << params.cpuparams.n_threads; + if (params.cpuparams_batch.n_threads != -1) { + os << " (n_threads_batch = " << params.cpuparams_batch.n_threads << ")"; } #if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later // TODO: windows + arm64 + mingw64 @@ -1824,17 +407,19 @@ std::string gpt_params_get_system_info(const gpt_params & params) { // String utils // -std::vector string_split(std::string input, char separator) { - std::vector parts; - size_t separator_pos = input.find(separator); - while (separator_pos != std::string::npos) { - std::string part = input.substr(0, separator_pos); - parts.emplace_back(part); - input = input.substr(separator_pos + 1); - separator_pos = input.find(separator); - } - parts.emplace_back(input); - return parts; +std::string string_format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + LM_GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + LM_GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); } std::string string_strip(const std::string & str) { @@ -1867,13 +452,107 @@ std::string string_get_sortable_timestamp() { void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); +} + +std::string string_from(bool value) { + return value ? "true" : "false"; +} + +std::string string_from(const std::vector & values) { + std::stringstream buf; + + buf << "[ "; + bool first = true; + for (auto e : values) { + if (first) { + first = false; + } else { + buf << ", "; + } + buf << std::to_string(e); + } + buf << " ]"; + + return buf.str(); +} + +std::string string_from(const struct llama_context * ctx, const std::vector & tokens) { + std::stringstream buf; + + buf << "[ "; + + bool first = true; + for (const auto & token : tokens) { + if (!first) { + buf << ", "; + } else { + first = false; + } + + auto detokenized = common_token_to_piece(ctx, token); + + detokenized.erase( + std::remove_if( + detokenized.begin(), + detokenized.end(), + [](const unsigned char c) { return !std::isprint(c); }), + detokenized.end()); + + buf << "'" << detokenized << "'" + << ":" << std::to_string(token); + } + + buf << " ]"; + + return buf.str(); +} + +std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) { + std::stringstream buf; + + buf << "[ "; + + bool first = true; + for (int i = 0; i < batch.n_tokens; ++i) { + if (!first) { + buf << ", "; + } else { + first = false; + } + + auto detokenized = common_token_to_piece(ctx, batch.token[i]); + + detokenized.erase( + std::remove_if( + detokenized.begin(), + detokenized.end(), + [](const unsigned char c) { return !std::isprint(c); }), + detokenized.end()); + + buf << "\n" << std::to_string(i) + << ":token '" << detokenized << "'" + << ":pos " << std::to_string(batch.pos[i]) + << ":n_seq_id " << std::to_string(batch.n_seq_id[i]) + << ":seq_id " << std::to_string(batch.seq_id[i][0]) + << ":logits " << std::to_string(batch.logits[i]); } + + buf << " ]"; + + return buf.str(); } void string_process_escapes(std::string & input) { @@ -1916,7 +595,7 @@ void string_process_escapes(std::string & input) { bool string_parse_kv_override(const char * data, std::vector & overrides) { const char * sep = strchr(data, '='); if (sep == nullptr || sep - data >= 128) { - fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); + LOG_ERR("%s: malformed KV override '%s'\n", __func__, data); return false; } llama_model_kv_override kvo; @@ -1939,20 +618,20 @@ bool string_parse_kv_override(const char * data, std::vector 127) { - fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); + LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); return false; } strncpy(kvo.val_str, sep, 127); kvo.val_str[127] = '\0'; } else { - fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); + LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data); return false; } overrides.emplace_back(std::move(kvo)); @@ -2149,30 +828,55 @@ std::string fs_get_cache_file(const std::string & filename) { // // Model utils // -struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { - llama_init_result iparams; - auto mparams = llama_model_params_from_gpt_params(params); +struct common_init_result common_init_from_params(common_params & params) { + common_init_result iparams; + auto mparams = common_model_params_to_llama(params); llama_model * model = nullptr; if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); + model = common_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); } else if (!params.model_url.empty()) { - model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); + model = common_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); } else { model = llama_load_model_from_file(params.model.c_str(), mparams); } if (model == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str()); return iparams; } - auto cparams = llama_context_params_from_gpt_params(params); + if (params.reranking) { + bool ok = true; + + if (llama_token_bos(model) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__); + ok = false; + } + + if (llama_token_eos(model) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__); + ok = false; + } + + if (llama_token_sep(model) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__); + ok = false; + } + + if (!ok) { + llama_free_model(model); + + return iparams; + } + } + + auto cparams = common_context_params_to_llama(params); llama_context * lctx = llama_new_context_with_model(model, cparams); if (lctx == NULL) { - fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); + LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); llama_free_model(model); return iparams; } @@ -2181,10 +885,11 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); - const auto cvec = llama_control_vector_load(params.control_vectors); + const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { llama_free(lctx); llama_free_model(model); + return iparams; } @@ -2197,18 +902,19 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { if (err) { llama_free(lctx); llama_free_model(model); + return iparams; } } // load and optionally apply lora adapters for (auto & la : params.lora_adapters) { - llama_lora_adapter_container loaded_la; + common_lora_adapter_container loaded_la; loaded_la.path = la.path; loaded_la.scale = la.scale; loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str()); if (loaded_la.adapter == nullptr) { - fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); + LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); llama_free(lctx); llama_free_model(model); return iparams; @@ -2216,27 +922,33 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters } if (!params.lora_init_without_apply) { - llama_lora_adapters_apply(lctx, iparams.lora_adapters); + common_lora_adapters_apply(lctx, iparams.lora_adapters); } - if (params.ignore_eos) { - params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); + params.sparams.ignore_eos = false; } if (params.warmup) { - LOG("warming up the model with an empty run\n"); + LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); std::vector tmp; llama_token bos = llama_token_bos(model); llama_token eos = llama_token_eos(model); // some models (e.g. T5) don't have a BOS token - if (bos != -1) { + if (bos != LLAMA_TOKEN_NULL) { tmp.push_back(bos); } - tmp.push_back(eos); + if (eos != LLAMA_TOKEN_NULL) { + tmp.push_back(eos); + } + if (tmp.empty()) { + tmp.push_back(0); + } if (llama_model_has_encoder(model)) { - llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); + llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); if (decoder_start_token_id == -1) { decoder_start_token_id = bos; @@ -2245,19 +957,20 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { tmp.push_back(decoder_start_token_id); } if (llama_model_has_decoder(model)) { - llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); } llama_kv_cache_clear(lctx); llama_synchronize(lctx); - llama_reset_timings(lctx); + llama_perf_context_reset(lctx); } iparams.model = model; iparams.context = lctx; + return iparams; } -void llama_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters) { +void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters) { llama_lora_adapter_clear(ctx); for (auto & la : lora_adapters) { if (la.scale != 0.0f) { @@ -2266,12 +979,14 @@ void llama_lora_adapters_apply(struct llama_context * ctx, std::vector 0) { + LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); + + CURLcode res = curl_easy_perform(curl); + if (res == CURLE_OK) { + return true; + } + + int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000; + LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); + + remaining_attempts--; + std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); + } + + LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts); + + return false; +} + +static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); if (!curl) { - fprintf(stderr, "%s: error initializing libcurl\n", __func__); + LOG_ERR("%s: error initializing libcurl\n", __func__); return false; } @@ -2405,11 +1169,11 @@ static bool llama_download_file(const std::string & url, const std::string & pat if (metadata_in.good()) { try { metadata_in >> metadata; - fprintf(stderr, "%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); + LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); if (metadata.contains("url") && metadata.at("url").is_string()) { auto previous_url = metadata.at("url").get(); if (previous_url != url) { - fprintf(stderr, "%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); + LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); return false; } } @@ -2420,24 +1184,24 @@ static bool llama_download_file(const std::string & url, const std::string & pat last_modified = metadata.at("lastModified"); } } catch (const nlohmann::json::exception & e) { - fprintf(stderr, "%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); return false; } } } else { - fprintf(stderr, "%s: no previous model file found %s\n", __func__, path.c_str()); + LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); } // Send a HEAD request to retrieve the etag and last-modified headers - struct llama_load_model_from_url_headers { + struct common_load_model_from_url_headers { std::string etag; std::string last_modified; }; - llama_load_model_from_url_headers headers; + common_load_model_from_url_headers headers; { typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { - llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata; + common_load_model_from_url_headers *headers = (common_load_model_from_url_headers *) userdata; static std::regex header_regex("([^:]+): (.*)\r\n"); static std::regex etag_regex("ETag", std::regex_constants::icase); @@ -2462,9 +1226,8 @@ static bool llama_download_file(const std::string & url, const std::string & pat curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - CURLcode res = curl_easy_perform(curl.get()); - if (res != CURLE_OK) { - fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); + if (!was_perform_successful) { return false; } @@ -2474,26 +1237,26 @@ static bool llama_download_file(const std::string & url, const std::string & pat // HEAD not supported, we don't know if the file has changed // force trigger downloading force_download = true; - fprintf(stderr, "%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); } } bool should_download = !file_exists || force_download; if (!should_download) { if (!etag.empty() && etag != headers.etag) { - fprintf(stderr, "%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); + LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); should_download = true; } else if (!last_modified.empty() && last_modified != headers.last_modified) { - fprintf(stderr, "%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); + LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); should_download = true; } } if (should_download) { std::string path_temporary = path + ".downloadInProgress"; if (file_exists) { - fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); + LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); if (remove(path.c_str()) != 0) { - fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path.c_str()); + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); return false; } } @@ -2508,7 +1271,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); if (!outfile) { - fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str()); + LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); return false; } @@ -2539,18 +1302,17 @@ static bool llama_download_file(const std::string & url, const std::string & pat }; // start the download - fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, - llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); - auto res = curl_easy_perform(curl.get()); - if (res != CURLE_OK) { - fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); + LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, + llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); + if (!was_perform_successful) { return false; } long http_code = 0; curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); if (http_code < 200 || http_code >= 400) { - fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code); + LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); return false; } @@ -2564,10 +1326,10 @@ static bool llama_download_file(const std::string & url, const std::string & pat {"lastModified", headers.last_modified} }); std::ofstream(metadata_path) << metadata.dump(4); - fprintf(stderr, "%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); + LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); if (rename(path_temporary.c_str(), path.c_str()) != 0) { - fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); return false; } } @@ -2575,18 +1337,18 @@ static bool llama_download_file(const std::string & url, const std::string & pat return true; } -struct llama_model * llama_load_model_from_url( +struct llama_model * common_load_model_from_url( const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params) { // Basic validation of the model_url if (!model_url || strlen(model_url) == 0) { - fprintf(stderr, "%s: invalid model_url\n", __func__); + LOG_ERR("%s: invalid model_url\n", __func__); return NULL; } - if (!llama_download_file(model_url, path_model, hf_token)) { + if (!common_download_file(model_url, path_model, hf_token)) { return NULL; } @@ -2599,7 +1361,7 @@ struct llama_model * llama_load_model_from_url( }; auto * ctx_gguf = lm_gguf_init_from_file(path_model, lm_gguf_params); if (!ctx_gguf) { - fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, path_model); + LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, path_model); return NULL; } @@ -2619,14 +1381,12 @@ struct llama_model * llama_load_model_from_url( // and extract split URL and PATH prefixes { if (!llama_split_prefix(split_prefix, sizeof(split_prefix), path_model, 0, n_split)) { - fprintf(stderr, "\n%s: unexpected model file name: %s" - " n_split=%d\n", __func__, path_model, n_split); + LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, path_model, n_split); return NULL; } if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url, 0, n_split)) { - fprintf(stderr, "\n%s: unexpected model url: %s" - " n_split=%d\n", __func__, model_url, n_split); + LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url, n_split); return NULL; } } @@ -2641,7 +1401,7 @@ struct llama_model * llama_load_model_from_url( char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); - return llama_download_file(split_url, split_path, hf_token); + return common_download_file(split_url, split_path, hf_token); }, idx)); } @@ -2656,7 +1416,7 @@ struct llama_model * llama_load_model_from_url( return llama_load_model_from_file(path_model, params); } -struct llama_model * llama_load_model_from_hf( +struct llama_model * common_load_model_from_hf( const char * repo, const char * model, const char * path_model, @@ -2676,27 +1436,27 @@ struct llama_model * llama_load_model_from_hf( model_url += "/resolve/main/"; model_url += model; - return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params); + return common_load_model_from_url(model_url.c_str(), path_model, hf_token, params); } #else -struct llama_model * llama_load_model_from_url( +struct llama_model * common_load_model_from_url( const char * /*model_url*/, const char * /*path_model*/, const char * /*hf_token*/, const struct llama_model_params & /*params*/) { - fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); + LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); return nullptr; } -struct llama_model * llama_load_model_from_hf( +struct llama_model * common_load_model_from_hf( const char * /*repo*/, const char * /*model*/, const char * /*path_model*/, const char * /*hf_token*/, const struct llama_model_params & /*params*/) { - fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); + LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); return nullptr; } @@ -2706,16 +1466,18 @@ struct llama_model * llama_load_model_from_hf( // Batch utils // -void llama_batch_clear(struct llama_batch & batch) { +void common_batch_clear(struct llama_batch & batch) { batch.n_tokens = 0; } -void llama_batch_add( +void common_batch_add( struct llama_batch & batch, llama_token id, llama_pos pos, const std::vector & seq_ids, bool logits) { + LM_GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); + batch.token [batch.n_tokens] = id; batch.pos [batch.n_tokens] = pos; batch.n_seq_id[batch.n_tokens] = seq_ids.size(); @@ -2731,15 +1493,15 @@ void llama_batch_add( // Vocab utils // -std::vector llama_tokenize( +std::vector common_tokenize( const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special) { - return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special); + return common_tokenize(llama_get_model(ctx), text, add_special, parse_special); } -std::vector llama_tokenize( +std::vector common_tokenize( const struct llama_model * model, const std::string & text, bool add_special, @@ -2758,7 +1520,7 @@ std::vector llama_tokenize( return result; } -std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { +std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { std::string piece; piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); @@ -2774,7 +1536,7 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t return piece; } -std::string llama_detokenize(llama_context * ctx, const std::vector & tokens, bool special) { +std::string common_detokenize(llama_context * ctx, const std::vector & tokens, bool special) { std::string text; text.resize(std::max(text.capacity(), tokens.size())); int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); @@ -2794,15 +1556,15 @@ std::string llama_detokenize(llama_context * ctx, const std::vector // Chat template utils // -bool llama_chat_verify_template(const std::string & tmpl) { +bool common_chat_verify_template(const std::string & tmpl) { llama_chat_message chat[] = {{"user", "test"}}; int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; } -std::string llama_chat_apply_template(const struct llama_model * model, +std::string common_chat_apply_template(const struct llama_model * model, const std::string & tmpl, - const std::vector & msgs, + const std::vector & msgs, bool add_ass) { int alloc_size = 0; bool fallback = false; // indicate if we must fallback to default chatml @@ -2844,42 +1606,42 @@ std::string llama_chat_apply_template(const struct llama_model * model, return formatted_chat; } -std::string llama_chat_format_single(const struct llama_model * model, +std::string common_chat_format_single(const struct llama_model * model, const std::string & tmpl, - const std::vector & past_msg, - const llama_chat_msg & new_msg, + const std::vector & past_msg, + const common_chat_msg & new_msg, bool add_ass) { std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false); - std::vector chat_new(past_msg); + auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false); + std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { ss << "\n"; }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass); + auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); } -std::string llama_chat_format_example(const struct llama_model * model, +std::string common_chat_format_example(const struct llama_model * model, const std::string & tmpl) { - std::vector msgs = { + std::vector msgs = { {"system", "You are a helpful assistant"}, {"user", "Hello"}, {"assistant", "Hi there"}, {"user", "How are you?"}, }; - return llama_chat_apply_template(model, tmpl, msgs, true); + return common_chat_apply_template(model, tmpl, msgs, true); } // // KV cache utils // -void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) { +void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) { static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+"; printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d", @@ -2902,7 +1664,7 @@ void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) { printf("\n=== Done dumping\n"); } -void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) { +void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) { static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n", @@ -2954,7 +1716,7 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) { +void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) { double sum = 0.0; switch (embd_norm) { @@ -2988,7 +1750,7 @@ void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) } } -float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n){ +float common_embd_similarity_cos(const float * embd1, const float * embd2, int n){ double sum = 0.0; double sum1 = 0.0; double sum2 = 0.0; @@ -3014,8 +1776,8 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n) // Control vector utils // -static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) { - llama_control_vector_data result = { -1, {} }; +static common_control_vector_data common_control_vector_load_one(const common_control_vector_load_info & load_info) { + common_control_vector_data result = { -1, {} }; lm_ggml_context * ctx = nullptr; struct lm_gguf_init_params meta_lm_gguf_params = { @@ -3024,13 +1786,13 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr }; struct lm_gguf_context * ctx_gguf = lm_gguf_init_from_file(load_info.fname.c_str(), meta_lm_gguf_params); if (!ctx_gguf) { - fprintf(stderr, "%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str()); return result; } int32_t n_tensors = lm_gguf_get_n_tensors(ctx_gguf); if (n_tensors == 0) { - fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str()); + LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str()); } for (int i = 0; i < n_tensors; i++) { @@ -3048,23 +1810,23 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr } } if (layer_idx < 0) { - fprintf(stderr, "%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); result.n_embd = -1; break; } else if (layer_idx == 0) { - fprintf(stderr, "%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); result.n_embd = -1; break; } struct lm_ggml_tensor * tensor = lm_ggml_get_tensor(ctx, name.c_str()); if (tensor->type != LM_GGML_TYPE_F32) { - fprintf(stderr, "%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str()); result.n_embd = -1; break; } if (lm_ggml_n_dims(tensor) != 1) { - fprintf(stderr, "%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str()); result.n_embd = -1; break; } @@ -3072,7 +1834,7 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr if (result.n_embd == -1) { result.n_embd = lm_ggml_nelements(tensor); } else if (lm_ggml_nelements(tensor) != result.n_embd) { - fprintf(stderr, "%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str()); result.n_embd = -1; break; } @@ -3089,7 +1851,7 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr } if (result.n_embd == -1) { - fprintf(stderr, "%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str()); + LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str()); result.data.clear(); } @@ -3099,18 +1861,18 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr return result; } -llama_control_vector_data llama_control_vector_load(const std::vector & load_infos) { - llama_control_vector_data result = { -1, {} }; +common_control_vector_data common_control_vector_load(const std::vector & load_infos) { + common_control_vector_data result = { -1, {} }; for (const auto & info : load_infos) { - auto cur = llama_control_vector_load_one(info); + auto cur = common_control_vector_load_one(info); if (cur.n_embd == -1) { result.n_embd = -1; break; } if (result.n_embd != -1 && result.n_embd != cur.n_embd) { - fprintf(stderr, "%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str()); + LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str()); result.n_embd = -1; break; } @@ -3126,7 +1888,7 @@ llama_control_vector_data llama_control_vector_load(const std::vector & prompt_tokens, const char * model_desc) { - const llama_sampling_params & sparams = params.sparams; + const auto & sparams = params.sparams; fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); @@ -3217,6 +1979,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "cpu_has_sve: %s\n", lm_ggml_cpu_has_sve() ? "true" : "false"); fprintf(stream, "cpu_has_f16c: %s\n", lm_ggml_cpu_has_f16c() ? "true" : "false"); fprintf(stream, "cpu_has_fp16_va: %s\n", lm_ggml_cpu_has_fp16_va() ? "true" : "false"); + fprintf(stream, "cpu_has_riscv_v: %s\n", lm_ggml_cpu_has_riscv_v() ? "true" : "false"); fprintf(stream, "cpu_has_wasm_simd: %s\n", lm_ggml_cpu_has_wasm_simd() ? "true" : "false"); fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false"); fprintf(stream, "cpu_has_sse3: %s\n", lm_ggml_cpu_has_sse3() ? "true" : "false"); @@ -3248,11 +2011,13 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); - fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); + fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length); + fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base); + fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier); + fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); @@ -3260,10 +2025,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - - const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx))); - const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; - fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); + fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false"); yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); @@ -3274,11 +2036,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); fprintf(stream, "logit_bias:\n"); - for (std::pair lb : sparams.logit_bias) { - if (ignore_eos && lb.first == logit_bias_eos->first) { - continue; - } - fprintf(stream, " %d: %f", lb.first, lb.second); + for (const auto & logit_bias : sparams.logit_bias) { + fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias); } fprintf(stream, "lora:\n"); @@ -3331,7 +2090,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); - fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); @@ -3340,12 +2098,13 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector); - fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); - fprintf(stream, "threads: %d # default: %u\n", params.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency()); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); - fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); + fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability); + fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold); + fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false"); } diff --git a/cpp/common.h b/cpp/common.h index 18b027b..5a79c8c 100644 --- a/cpp/common.h +++ b/cpp/common.h @@ -4,18 +4,9 @@ #include "llama.h" -#include "sampling.h" - -#define LOG_NO_FILE_LINE_FUNCTION -#include "log.h" - -#include #include #include -#include -#include -#include -#include +#include #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -33,12 +24,12 @@ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" -struct llama_lora_adapter_info { +struct common_lora_adapter_info { std::string path; float scale; }; -struct llama_lora_adapter_container : llama_lora_adapter_info { +struct common_lora_adapter_container : common_lora_adapter_info { struct llama_lora_adapter * adapter; }; @@ -48,7 +39,7 @@ extern char const * LLAMA_COMMIT; extern char const * LLAMA_COMPILER; extern char const * LLAMA_BUILD_TARGET; -struct llama_control_vector_load_info; +struct common_control_vector_load_info; #define print_build_info() do { \ fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ @@ -65,26 +56,116 @@ extern char const *LLAMA_BUILD_TARGET; // CPU utils // +struct cpu_params { + int n_threads = -1; + bool cpumask[LM_GGML_MAX_N_THREADS] = {false}; // CPU affinity mask. + bool mask_valid = false; // Default: any CPU + enum lm_ggml_sched_priority priority = LM_GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime) + bool strict_cpu = false; // Use strict CPU placement + uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling) +}; + int32_t cpu_get_num_physical_cores(); int32_t cpu_get_num_math(); // -// CLI argument parsing +// Common params // +enum llama_example { + LLAMA_EXAMPLE_COMMON, + LLAMA_EXAMPLE_SPECULATIVE, + LLAMA_EXAMPLE_MAIN, + LLAMA_EXAMPLE_INFILL, + LLAMA_EXAMPLE_EMBEDDING, + LLAMA_EXAMPLE_PERPLEXITY, + LLAMA_EXAMPLE_RETRIEVAL, + LLAMA_EXAMPLE_PASSKEY, + LLAMA_EXAMPLE_IMATRIX, + LLAMA_EXAMPLE_BENCH, + LLAMA_EXAMPLE_SERVER, + LLAMA_EXAMPLE_CVECTOR_GENERATOR, + LLAMA_EXAMPLE_EXPORT_LORA, + LLAMA_EXAMPLE_LLAVA, + LLAMA_EXAMPLE_LOOKUP, + LLAMA_EXAMPLE_PARALLEL, + + LLAMA_EXAMPLE_COUNT, +}; + +enum common_sampler_type { + COMMON_SAMPLER_TYPE_NONE = 0, + COMMON_SAMPLER_TYPE_DRY = 1, + COMMON_SAMPLER_TYPE_TOP_K = 2, + COMMON_SAMPLER_TYPE_TOP_P = 3, + COMMON_SAMPLER_TYPE_MIN_P = 4, + //COMMON_SAMPLER_TYPE_TFS_Z = 5, + COMMON_SAMPLER_TYPE_TYPICAL_P = 6, + COMMON_SAMPLER_TYPE_TEMPERATURE = 7, + COMMON_SAMPLER_TYPE_XTC = 8, + COMMON_SAMPLER_TYPE_INFILL = 9, +}; + // dimensionality reduction methods, used by cvector-generator enum dimre_method { DIMRE_METHOD_PCA, DIMRE_METHOD_MEAN, }; -struct gpt_params { - uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed +// sampler parameters +struct common_sampler_params { + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler + + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float xtc_probability = 0.00f; // 0.0 = disabled + float xtc_threshold = 0.10f; // > 0.5 disables XTC + float typ_p = 1.00f; // typical_p, 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: + float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) + int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty + int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool penalize_nl = false; // consider newlines as a repeatable token + bool ignore_eos = false; + bool no_perf = false; // disable performance metrics + + std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY + + + std::vector samplers = { + COMMON_SAMPLER_TYPE_DRY, + COMMON_SAMPLER_TYPE_TOP_K, + COMMON_SAMPLER_TYPE_TYPICAL_P, + COMMON_SAMPLER_TYPE_TOP_P, + COMMON_SAMPLER_TYPE_MIN_P, + COMMON_SAMPLER_TYPE_XTC, + COMMON_SAMPLER_TYPE_TEMPERATURE, + }; + + std::string grammar; // optional BNF-like grammar to constrain sampling + + std::vector logit_bias; // logit biases to apply + + // print the parameters into a string + std::string print() const; +}; - int32_t n_threads = cpu_get_num_math(); - int32_t n_threads_draft = -1; - int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) - int32_t n_threads_batch_draft = -1; +struct common_params { + bool vocab_only = false; int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 0; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) @@ -111,6 +192,11 @@ struct gpt_params { int32_t yarn_orig_ctx = 0; // YaRN original context length float defrag_thold = -1.0f; // KV cache defragmentation threshold + struct cpu_params cpuparams; + struct cpu_params cpuparams_batch; + struct cpu_params draft_cpuparams; + struct cpu_params draft_cpuparams_batch; + lm_ggml_backend_sched_eval_callback cb_eval = nullptr; void * cb_eval_user_data = nullptr; @@ -121,35 +207,34 @@ struct gpt_params { enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings - // // sampling parameters - struct llama_sampling_params sparams; - - std::string model = ""; // model path - std::string model_draft = ""; // draft model for speculative decoding - std::string model_alias = "unknown"; // model alias - std::string model_url = ""; // model url to download - std::string hf_token = ""; // HF token - std::string hf_repo = ""; // HF repo - std::string hf_file = ""; // HF file - std::string prompt = ""; - std::string prompt_file = ""; // store the external prompt file name - std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state - std::string input_prefix = ""; // string to prefix user inputs with - std::string input_suffix = ""; // string to suffix user inputs with - std::string logdir = ""; // directory in which to save YAML log files - std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding - std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding - std::string logits_file = ""; // file for saving *all* logits - std::string rpc_servers = ""; // comma separated list of RPC servers + struct common_sampler_params sparams; + + std::string model = ""; // model path // NOLINT + std::string model_draft = ""; // draft model for speculative decoding // NOLINT + std::string model_alias = "unknown"; // model alias // NOLINT + std::string model_url = ""; // model url to download // NOLINT + std::string hf_token = ""; // HF token // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT + std::string prompt = ""; // NOLINT + std::string prompt_file = ""; // store the external prompt file name // NOLINT + std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT + std::string input_prefix = ""; // string to prefix user inputs with // NOLINT + std::string input_suffix = ""; // string to suffix user inputs with // NOLINT + std::string logdir = ""; // directory in which to save YAML log files // NOLINT + std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT + std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT + std::string logits_file = ""; // file for saving *all* logits // NOLINT + std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply) - std::vector lora_adapters; // lora adapter path with user defined scale + std::vector lora_adapters; // lora adapter path with user defined scale - std::vector control_vectors; // control vector with user defined scale + std::vector control_vectors; // control vector with user defined scale int32_t verbosity = 0; int32_t control_vector_layer_start = -1; // layer range for control vector @@ -184,15 +269,15 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + bool no_perf = false; // disable performance metrics + bool ctx_shift = true; // context shift on inifinite text generation bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix - bool ignore_eos = false; // ignore generated EOS tokens bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory bool verbose_prompt = false; // print prompt tokens before generation bool display_prompt = true; // print prompt before generation - bool infill = false; // use infill mode bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run @@ -202,33 +287,37 @@ struct gpt_params { std::string cache_type_v = "f16"; // KV cache data type for the V // multimodal models (see examples/llava) - std::string mmproj = ""; // path to multimodal projector + std::string mmproj = ""; // path to multimodal projector // NOLINT std::vector image; // path to image file(s) // embedding bool embedding = false; // get only sentence embedding - int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) + int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix - std::string embd_sep = "\n"; // separator of embendings + std::string embd_sep = "\n"; // separator of embeddings + bool reranking = false; // enable reranking support on server // server params int32_t port = 8080; // server listens on this network port int32_t timeout_read = 600; // http read timeout in seconds int32_t timeout_write = timeout_read; // http write timeout in seconds - int32_t n_threads_http = -1; // number of threads to process HTTP requests + int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) + int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting std::string hostname = "127.0.0.1"; - std::string public_path = ""; - std::string chat_template = ""; - std::string system_prompt = ""; + std::string public_path = ""; // NOLINT + std::string chat_template = ""; // NOLINT bool enable_chat_template = true; std::vector api_keys; - std::string ssl_file_key = ""; - std::string ssl_file_cert = ""; + std::string ssl_file_key = ""; // NOLINT + std::string ssl_file_cert = ""; // NOLINT - bool endpoint_slots = true; + // "advanced" endpoints are disabled by default for better security + bool webui = true; + bool endpoint_slots = false; + bool endpoint_props = false; // only control POST requests, not GET bool endpoint_metrics = false; bool log_json = false; @@ -276,23 +365,38 @@ struct gpt_params { bool spm_infill = false; // suffix/prefix/middle pattern for infill std::string lora_outfile = "ggml-lora-merged-f16.gguf"; + + // batched-bench params + bool batched_bench_output_jsonl = false; }; -void gpt_params_parse_from_env(gpt_params & params); -void gpt_params_handle_model_default(gpt_params & params); +// call once at the start of a program if it uses libcommon +// initializes the logging system and prints info about the build +void common_init(); -bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params); -bool gpt_params_parse (int argc, char ** argv, gpt_params & params); -bool gpt_params_find_arg (int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param); -void gpt_params_print_usage(int argc, char ** argv, const gpt_params & params); +std::string common_params_get_system_info(const common_params & params); -std::string gpt_params_get_system_info(const gpt_params & params); +bool parse_cpu_range(const std::string & range, bool(&boolmask)[LM_GGML_MAX_N_THREADS]); +bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[LM_GGML_MAX_N_THREADS]); +void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr); +bool set_process_priority(enum lm_ggml_sched_priority prio); // // String utils // -std::vector string_split(std::string input, char separator); +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) +#endif + +LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) +std::string string_format(const char * fmt, ...); std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); @@ -301,6 +405,7 @@ void string_replace_all(std::string & s, const std::string & search, const std:: template static std::vector string_split(const std::string & str, char delim) { + static_assert(!std::is_same::value, "Please use the specialized version for std::string"); std::vector values; std::istringstream str_stream(str); std::string token; @@ -313,9 +418,30 @@ static std::vector string_split(const std::string & str, char delim) { return values; } +template<> +std::vector string_split(const std::string & input, char separator) +{ + std::vector parts; + size_t begin_pos = 0; + size_t separator_pos = input.find(separator); + while (separator_pos != std::string::npos) { + std::string part = input.substr(begin_pos, separator_pos - begin_pos); + parts.emplace_back(part); + begin_pos = separator_pos + 1; + separator_pos = input.find(separator, begin_pos); + } + parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos)); + return parts; +} + bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); +std::string string_from(bool value); +std::string string_from(const std::vector & values); +std::string string_from(const struct llama_context * ctx, const std::vector & tokens); +std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch); + // // Filesystem utils // @@ -330,28 +456,29 @@ std::string fs_get_cache_file(const std::string & filename); // Model utils // -struct llama_init_result { +struct common_init_result { struct llama_model * model = nullptr; struct llama_context * context = nullptr; - std::vector lora_adapters; + std::vector lora_adapters; }; -struct llama_init_result llama_init_from_gpt_params(gpt_params & params); +struct common_init_result common_init_from_params(common_params & params); -struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); -struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); +struct llama_model_params common_model_params_to_llama (const common_params & params); +struct llama_context_params common_context_params_to_llama(const common_params & params); +struct lm_ggml_threadpool_params lm_ggml_threadpool_params_from_cpu_params(const cpu_params & params); -struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params); -struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params); +struct llama_model * common_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params); +struct llama_model * common_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params); // clear LoRA adapters from context, then apply new list of adapters -void llama_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters); +void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters); // Batch utils -void llama_batch_clear(struct llama_batch & batch); +void common_batch_clear(struct llama_batch & batch); -void llama_batch_add( +void common_batch_add( struct llama_batch & batch, llama_token id, llama_pos pos, @@ -364,13 +491,13 @@ void llama_batch_add( // tokenizes a string into a vector of tokens // should work similar to Python's `tokenizer.encode` -std::vector llama_tokenize( +std::vector common_tokenize( const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special = false); -std::vector llama_tokenize( +std::vector common_tokenize( const struct llama_model * model, const std::string & text, bool add_special, @@ -378,7 +505,7 @@ std::vector llama_tokenize( // tokenizes a token into a piece, optionally renders special/control tokens // should work similar to Python's `tokenizer.id_to_piece` -std::string llama_token_to_piece( +std::string common_token_to_piece( const struct llama_context * ctx, llama_token token, bool special = true); @@ -386,7 +513,7 @@ std::string llama_token_to_piece( // detokenizes a vector of tokens into a string // should work similar to Python's `tokenizer.decode` // optionally renders special/control tokens -std::string llama_detokenize( +std::string common_detokenize( llama_context * ctx, const std::vector & tokens, bool special = true); @@ -396,31 +523,31 @@ std::string llama_detokenize( // // same with llama_chat_message, but uses std::string -struct llama_chat_msg { +struct common_chat_msg { std::string role; std::string content; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid -bool llama_chat_verify_template(const std::string & tmpl); +bool common_chat_verify_template(const std::string & tmpl); // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml // If the custom "tmpl" is not supported, we throw an error -std::string llama_chat_apply_template(const struct llama_model * model, +std::string common_chat_apply_template(const struct llama_model * model, const std::string & tmpl, - const std::vector & chat, + const std::vector & chat, bool add_ass); // Format single message, while taking into account the position of that message in chat history -std::string llama_chat_format_single(const struct llama_model * model, +std::string common_chat_format_single(const struct llama_model * model, const std::string & tmpl, - const std::vector & past_msg, - const llama_chat_msg & new_msg, + const std::vector & past_msg, + const common_chat_msg & new_msg, bool add_ass); // Returns an example of formatted chat -std::string llama_chat_format_example(const struct llama_model * model, +std::string common_chat_format_example(const struct llama_model * model, const std::string & tmpl); // @@ -428,31 +555,31 @@ std::string llama_chat_format_example(const struct llama_model * model, // // Dump the KV cache view with the number of sequences per cell. -void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80); +void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). -void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40); +void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40); // // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2); +void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2); -float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n); +float common_embd_similarity_cos(const float * embd1, const float * embd2, int n); // // Control vector utils // -struct llama_control_vector_data { +struct common_control_vector_data { int n_embd; // stores data for layers [1, n_layer] where n_layer = data.size() / n_embd std::vector data; }; -struct llama_control_vector_load_info { +struct common_control_vector_load_info { float strength; std::string fname; @@ -460,7 +587,7 @@ struct llama_control_vector_load_info { // Load control vectors, scale each by strength, and add them together. // On error, returns {-1, empty} -llama_control_vector_data llama_control_vector_load(const std::vector & load_infos); +common_control_vector_data common_control_vector_load(const std::vector & load_infos); // // Split utils @@ -479,5 +606,5 @@ void yaml_dump_vector_int (FILE * stream, const char * prop_name, const std void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data); void yaml_dump_non_result_info( - FILE * stream, const gpt_params & params, const llama_context * lctx, + FILE * stream, const common_params & params, const llama_context * lctx, const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); diff --git a/cpp/ggml-aarch64.c b/cpp/ggml-aarch64.c index 9c507d4..d5f9419 100644 --- a/cpp/ggml-aarch64.c +++ b/cpp/ggml-aarch64.c @@ -1,9 +1,13 @@ -// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + #define LM_GGML_COMMON_IMPL_C #include "ggml-common.h" #include "ggml-quants.h" #include "ggml-impl.h" +#include "ggml-cpu-impl.h" #include #include @@ -36,6 +40,152 @@ // from bias offset form to pure sign form (this saves subtract // operations durin unpacking) // +#if defined(__AVX__) +#if defined(__F16C__) +#if defined(__AVX512F__) +#define LM_GGML_F32Cx8x2_LOAD(x, y) _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x)))) +#define LM_GGML_F32Cx16_REPEAT_LOAD(x) _mm512_cvtph_ps(_mm256_set_m128i(x, x)) +#endif +// the _mm256_cvt intrinsics require F16C +#define LM_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) +#define LM_GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68)) +#define LM_GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)) +#else +#if defined(__AVX512F__) +static inline __m512 __avx512_f32cx8x2_load(lm_ggml_fp16_t *x, lm_ggml_fp16_t *y) { + float tmp[16]; + + for (int i = 0; i < 8; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); + } + + for (int i = 0; i < 8; i++) { + tmp[i + 8] = LM_GGML_FP16_TO_FP32(y[i]); + } + + return _mm512_loadu_ps(tmp); +} +static inline __m512 __avx512_repeat_f32cx16_load(__m128i x) { + float tmp[16]; + uint16_t tmphalf[8]; + _mm_storeu_si128((__m128i*)tmphalf, x); + + for (int i = 0; i < 4; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(tmphalf[i]); + tmp[i + 4] = LM_GGML_FP16_TO_FP32(tmphalf[i]); + tmp[i + 8] = LM_GGML_FP16_TO_FP32(tmphalf[i]); + tmp[i + 12] = LM_GGML_FP16_TO_FP32(tmphalf[i]); + } + + return _mm512_loadu_ps(tmp); +} +#endif +static inline __m256 __avx_f32cx8_load(lm_ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline __m256 __avx_repeat_f32cx8_load(lm_ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 4; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); + tmp[i + 4] = LM_GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline __m256 __avx_rearranged_f32cx8_load(lm_ggml_fp16_t *x, __m128i arrangeMask) { + uint16_t tmphalf[8]; + float tmp[8]; + + _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)); + for (int i = 0; i < 8; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(tmphalf[i]); + } + + return _mm256_loadu_ps(tmp); +} + +#define LM_GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) +#define LM_GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x) +#define LM_GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask) +#if defined(__AVX512F__) +#define LM_GGML_F32Cx8x2_LOAD(x, y) __avx512_f32cx8x2_load(x, y) +#define LM_GGML_F32Cx16_REPEAT_LOAD(x) __avx512_repeat_f32cx16_load(x) +#endif +#endif +#endif + + +#if defined(__AVX2__) || defined(__AVX512F__) +#if defined(__AVX512F__) +// add int16_t pairwise and return as 512 bit int vector +static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) { + const __m512i ones = _mm512_set1_epi16(1); + return _mm512_madd_epi16(ones, x); +} + +static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) { +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) + const __m512i zero = _mm512_setzero_si512(); + return _mm512_dpbusd_epi32(zero, ax, sy); +#else + // Perform multiplication and create 16-bit values + const __m512i dot = _mm512_maddubs_epi16(ax, sy); + return sum_i16_pairs_int_32x16(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as 512 bit int vector +static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) { + const __m512i zero = _mm512_setzero_si512(); + // Get absolute values of x vectors + const __m512i ax = _mm512_abs_epi8(x); + // Sign the values of the y vectors + __mmask64 blt0 = _mm512_movepi8_mask(x); + const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y); + return mul_sum_us8_pairs_int32x16(ax, sy); +} +#endif + +// add int16_t pairwise and return as 256 bit int vector +static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + return _mm256_madd_epi16(ones, x); +} + +static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) { +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbusd_epi32(zero, ax, sy); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_int32x8(dot); +#endif +} + +// Integer variant of the function defined in ggml-quants.c +// multiply int8_t, add results pairwise twice and return as 256 bit int vector +static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbssd_epi32(zero, x, y); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_int32x8(ax, sy); +#endif +} +#endif + static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { block_q4_0x4 out; @@ -255,6 +405,103 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); } } +#elif defined(__AVX2__) || defined(__AVX__) + float id[4]; + __m256 srcv[4][4]; + __m256 idvec[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 ); + __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 ); + __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 ); + __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 ); + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Divided by 127.f to mirror results in quantize_row_q8_0 + const float d = maxScalar / 127.f; + id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f; + + // Store the scale for the individual block + y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); + + // Store the values in blocks of eight values - Aim is to use these later for block interleaving + srcv[row_iter][0] = v0; + srcv[row_iter][1] = v1; + srcv[row_iter][2] = v2; + srcv[row_iter][3] = v3; + idvec[row_iter] = _mm256_set1_ps(id[row_iter]); + } + + // The loop iterates four times - The aim is to get 4 corresponding chunks of eight bytes from the original weight blocks that are interleaved + for (int j = 0; j < 4; j++) { + // Apply the multiplier + __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]); + __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]); + __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]); + __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); + + // Permute and store the quantized weights in the required order after the pack instruction + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4); +#endif + } + } #else // scalar const int blck_size_interleave = 8; @@ -337,33 +584,18 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds } size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); - } - else { - assert(false); - return 0; - } + UNUSED(quant_weights); + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); } size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); - } - else { - assert(false); - return 0; - } + UNUSED(quant_weights); + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); } size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); - } - else { - assert(false); - return 0; - } + UNUSED(quant_weights); + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); } void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { @@ -385,73 +617,67 @@ void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__ARM_FEATURE_SVE) - if (lm_ggml_sve_cnt_b == QK8_0) { - LM_GGML_ASSERT(!(lm_ggml_cpu_has_sve() && (lm_ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + if (lm_ggml_cpu_has_neon()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v31.16b, #0x4\n" + "movi v30.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "movi v29.16b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ldr q28, [%x[b_ptr], #0x0]\n" + "ldr q27, [x22, #0x0]\n" + "movi v26.4s, #0x0\n" + "sub x20, x22, #0x2\n" + "ldr q25, [x22, #0x10]\n" + "ldr q24, [%x[b_ptr], #0x10]\n" + "sub x21, x21, #0x1\n" + "add x22, x22, #0x22\n" + "ldr q23, [%x[b_ptr], #0x20]\n" + "ldr q22, [%x[b_ptr], #0x30]\n" + "ld1r { v21.8h }, [x20]\n" + "ldr q20, [%x[b_ptr], #-0x8]\n" + "sshl v16.16b, v28.16b, v31.16b\n" + "and v28.16b, v28.16b, v30.16b\n" + "sshl v19.16b, v24.16b, v31.16b\n" + "and v24.16b, v24.16b, v30.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "sshl v18.16b, v23.16b, v31.16b\n" + "and v23.16b, v23.16b, v30.16b\n" + ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" + "sshl v17.16b, v22.16b, v31.16b\n" + "and v22.16b, v22.16b, v30.16b\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v16.4s, v20.4h\n" + ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" + "fmul v16.4s, v16.4s, v21.4s\n" + ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" + ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" + ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" + ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" + ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" + ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v29.4s, v26.4s, v16.4s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q29, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" + ); + return; } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - LM_GGML_ASSERT(!(lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) && - "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "movi v31.16b, #0x4\n" - "movi v30.16b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x8\n" - "1:" // Column loop - "add x22, %x[a_ptr], #0x2\n" - "movi v29.16b, #0x0\n" - "mov x21, %x[nb]\n" - "2:" // Block loop - "ldr q28, [%x[b_ptr], #0x0]\n" - "ldr q27, [x22, #0x0]\n" - "movi v26.4s, #0x0\n" - "sub x20, x22, #0x2\n" - "ldr q25, [x22, #0x10]\n" - "ldr q24, [%x[b_ptr], #0x10]\n" - "sub x21, x21, #0x1\n" - "add x22, x22, #0x22\n" - "ldr q23, [%x[b_ptr], #0x20]\n" - "ldr q22, [%x[b_ptr], #0x30]\n" - "ld1r { v21.8h }, [x20]\n" - "ldr q20, [%x[b_ptr], #-0x8]\n" - "sshl v16.16b, v28.16b, v31.16b\n" - "and v28.16b, v28.16b, v30.16b\n" - "sshl v19.16b, v24.16b, v31.16b\n" - "and v24.16b, v24.16b, v30.16b\n" - "add %x[b_ptr], %x[b_ptr], #0x48\n" - "sshl v18.16b, v23.16b, v31.16b\n" - "and v23.16b, v23.16b, v30.16b\n" - ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" - "sshl v17.16b, v22.16b, v31.16b\n" - "and v22.16b, v22.16b, v30.16b\n" - "fcvtl v21.4s, v21.4h\n" - "fcvtl v16.4s, v20.4h\n" - ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" - "fmul v16.4s, v16.4s, v21.4s\n" - ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" - ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" - ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" - ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" - ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" - ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "fmla v29.4s, v26.4s, v16.4s\n" - "cbnz x21, 2b\n" - "sub %x[nc], %x[nc], #0x4\n" - "str q29, [%x[res_ptr], #0x0]\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" - ); -#else +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) float sumf[4]; int sumi; @@ -475,7 +701,6 @@ void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void } for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } -#endif } void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { @@ -497,79 +722,72 @@ void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__ARM_FEATURE_SVE) - if (lm_ggml_sve_cnt_b == QK8_0) { - LM_GGML_ASSERT(!(lm_ggml_cpu_has_sve() && (lm_ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v2.16b, #0x4\n" + "movi v1.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x23, %x[a_ptr], #0x2\n" + "movi v0.16b, #0x0\n" + "mov x22, %x[nb]\n" + "2:" // Block loop + "ldr q31, [%x[b_ptr], #0x0]\n" + "ldr q30, [%x[b_ptr], #0x10]\n" + "mov x21, x23\n" + "movi v29.4s, #0x0\n" + "ldr q28, [%x[b_ptr], #0x20]\n" + "ldr q27, [%x[b_ptr], #0x30]\n" + "movi v26.4s, #0x0\n" + "sub x20, x23, #0x2\n" + "ld1r { v25.8h }, [x20]\n" + "ldr q24, [%x[b_ptr], #-0x8]\n" + "sub x22, x22, #0x1\n" + "add x23, x23, #0x22\n" + "ld1r { v23.2d }, [x21], #0x8\n" + "sshl v22.16b, v31.16b, v2.16b\n" + "sshl v16.16b, v30.16b, v2.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "ld1r { v21.2d }, [x21], #0x8\n" + "sshl v20.16b, v28.16b, v2.16b\n" + "sshl v19.16b, v27.16b, v2.16b\n" + "ld1r { v18.2d }, [x21], #0x8\n" + "ld1r { v17.2d }, [x21], #0x8\n" + "and v31.16b, v31.16b, v1.16b\n" + "and v30.16b, v30.16b, v1.16b\n" + ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" + ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" + "and v28.16b, v28.16b, v1.16b\n" + "and v27.16b, v27.16b, v1.16b\n" + "fcvtl v25.4s, v25.4h\n" + "fcvtl v16.4s, v24.4h\n" + ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" + ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" + "fmul v16.4s, v16.4s, v25.4s\n" + ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" + ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" + ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" + ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" + "addp v29.4s, v29.4s, v26.4s\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v0.4s, v29.4s, v16.4s\n" + "cbnz x22, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q0, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" + ); + return; } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "movi v2.16b, #0x4\n" - "movi v1.16b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x8\n" - "1:" // Column loop - "add x23, %x[a_ptr], #0x2\n" - "movi v0.16b, #0x0\n" - "mov x22, %x[nb]\n" - "2:" // Block loop - "ldr q31, [%x[b_ptr], #0x0]\n" - "ldr q30, [%x[b_ptr], #0x10]\n" - "mov x21, x23\n" - "movi v29.4s, #0x0\n" - "ldr q28, [%x[b_ptr], #0x20]\n" - "ldr q27, [%x[b_ptr], #0x30]\n" - "movi v26.4s, #0x0\n" - "sub x20, x23, #0x2\n" - "ld1r { v25.8h }, [x20]\n" - "ldr q24, [%x[b_ptr], #-0x8]\n" - "sub x22, x22, #0x1\n" - "add x23, x23, #0x22\n" - "ld1r { v23.2d }, [x21], #0x8\n" - "sshl v22.16b, v31.16b, v2.16b\n" - "sshl v16.16b, v30.16b, v2.16b\n" - "add %x[b_ptr], %x[b_ptr], #0x48\n" - "ld1r { v21.2d }, [x21], #0x8\n" - "sshl v20.16b, v28.16b, v2.16b\n" - "sshl v19.16b, v27.16b, v2.16b\n" - "ld1r { v18.2d }, [x21], #0x8\n" - "ld1r { v17.2d }, [x21], #0x8\n" - "and v31.16b, v31.16b, v1.16b\n" - "and v30.16b, v30.16b, v1.16b\n" - ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" - ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" - "and v28.16b, v28.16b, v1.16b\n" - "and v27.16b, v27.16b, v1.16b\n" - "fcvtl v25.4s, v25.4h\n" - "fcvtl v16.4s, v24.4h\n" - ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" - ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" - "fmul v16.4s, v16.4s, v25.4s\n" - ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" - ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" - ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" - ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" - "addp v29.4s, v29.4s, v26.4s\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "fmla v0.4s, v29.4s, v16.4s\n" - "cbnz x22, 2b\n" - "sub %x[nc], %x[nc], #0x4\n" - "str q0, [%x[res_ptr], #0x0]\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" - ); -#elif defined(__ARM_NEON) && defined(__aarch64__) - LM_GGML_ASSERT((lm_ggml_cpu_has_sve() || lm_ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#else +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) float sumf[4]; int sumi; @@ -593,7 +811,6 @@ void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void } for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } -#endif } void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { @@ -615,8 +832,9 @@ void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__ARM_FEATURE_SVE) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - if (lm_ggml_sve_cnt_b == QK8_0) { +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) +#if defined(__ARM_FEATURE_SVE) + if (lm_ggml_cpu_has_sve() && lm_ggml_cpu_get_sve_cnt() == QK8_0) { const void * b_ptr = vx; const void * a_ptr = vy; float * res_ptr = s; @@ -681,49 +899,191 @@ void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void ); return; } - else if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { - LM_GGML_ASSERT((lm_ggml_cpu_has_sve() && (lm_ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " - "performance"); +#endif // #if defined(__ARM_FEATURE_SVE) +#elif defined(__AVX2__) + // Lookup table to convert signed nibbles to signed bytes + __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); + __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + + // Permute mask used for easier vector processing at later stages + const __m256i m4b = _mm256_set1_epi8(0x0F); + + int64_t b_nb = n / QK4_0; + + const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; + const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy; + + // Process Q8_0 blocks one by one + for (int64_t y = 0; y < nr; y++) { + + // Pointers to LHS blocks of block_q8_0 format + const block_q8_0 * a_ptr = a_ptr_start + (y * nb); + + // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < nc / 8; x++) { + + // Pointers to RHS blocks + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulator + __m256 acc_row = _mm256_setzero_ps(); + + for (int64_t b = 0; b < nb; b++) { + // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7) + const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1); + const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2); + const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7) + const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7) + const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) + const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) + + const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23) + const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23) + const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31) + const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31) + + // Load the scale values for the 8 blocks interleaved in block_q4_0x8 + const __m256 col_scale_f32 = LM_GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); + + // Load and convert to FP32 scale from block_q8_0 + const __m256 row_scale_f32 = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(a_ptr[b].d)); + + // Load the block values in block_q8_0 in batches of 16 bytes and replicate the same across 256 bit vector + __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs)); + __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16))); + + lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0 (0-15) A0(0-15) + lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0 (16-31) A0(16-31)) + + __m256i iacc = _mm256_setzero_si256(); + + // Dot product done within 32 bit lanes and accumulated in the same vector + // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3) + // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7) + // ........................................................................... + // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31) + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255))); + + // Accumulated values multipled with appropriate scales + acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); + } + + // Accumulated output values permuted so as to be stored in appropriate order post accumulation + acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask); + _mm256_storeu_ps(s + (y * nr + x * 8), acc_row); + } } - else if (lm_ggml_cpu_has_neon()) { - LM_GGML_ASSERT(((lm_ggml_cpu_has_sve() && (lm_ggml_sve_cnt_b == QK8_0)) || lm_ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " - "quantization format for optimal performance"); + return; +#elif defined(__riscv_v_intrinsic) + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4)); + + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); + + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + // vector version needs Zvfhmin extension + const float a_scale = LM_GGML_FP16_TO_FP32(a_ptr[l].d); + const float b_scales[8] = { + LM_GGML_FP16_TO_FP32(b_ptr[l].d[0]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[1]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[2]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[3]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[4]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[5]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[6]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4); + sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4); + } + __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4); + } + return; } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - LM_GGML_ASSERT(lm_ggml_cpu_has_sve() && - "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) - LM_GGML_ASSERT((lm_ggml_cpu_has_sve() || lm_ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#else - float sumf[8]; - int sumi; +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) + { + float sumf[8]; + int sumi; - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); } - sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); } } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } -#endif } void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { @@ -746,505 +1106,500 @@ void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) - if (lm_ggml_sve_cnt_b == QK8_0) { - LM_GGML_ASSERT(!(lm_ggml_cpu_has_sve() && (lm_ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - LM_GGML_ASSERT(!(lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) && - "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v23.16b, #0x0\n" - "movi v16.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v0.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v21.16b, #0x0\n" - "movi v8.16b, #0x0\n" - "movi v1.16b, #0x0\n" - "3:" // Block loop - "ldr q3, [x28, #0x0]\n" - "ldr q31, [x25, #0x0]\n" - "movi v28.16b, #0x4\n" - "movi v10.4s, #0x0\n" - "ldr q22, [x28, #0x10]\n" - "ldr q6, [x25, #0x10]\n" - "movi v29.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "ldr q27, [x28, #0x20]\n" - "ldr q30, [x28, #0x30]\n" - "movi v20.4s, #0x0\n" - "movi v24.16b, #0xf0\n" - "ldr d2, [x25, #-0x8]\n" - "ldr d26, [x23, #-0x8]\n" - "sshl v12.16b, v3.16b, v28.16b\n" - "sub x20, x28, #0x8\n" - "ldr d17, [x20, #0x0]\n" - "and v3.16b, v3.16b, v24.16b\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" - ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" - ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" - ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" - "sshl v31.16b, v22.16b, v28.16b\n" - "and v22.16b, v22.16b, v24.16b\n" - "fcvtl v17.4s, v17.4h\n" - "fcvtl v2.4s, v2.4h\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" - ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" - ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" - ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" - "sshl v6.16b, v27.16b, v28.16b\n" - "sshl v28.16b, v30.16b, v28.16b\n" - "and v27.16b, v27.16b, v24.16b\n" - "and v30.16b, v30.16b, v24.16b\n" - "ldr q24, [x25, #0x20]\n" - ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x30]\n" - ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" - ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" - ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" - ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x40]\n" - ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x50]\n" - ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" - ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" - ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" - ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x60]\n" - ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" - ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" - ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" - ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" - "fmul v24.4s, v17.4s, v2.s[0]\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v15.4s, v10.4s, v24.4s\n" - "ldr q24, [x23, #0x0]\n" - "fmul v10.4s, v17.4s, v2.s[1]\n" - "fmla v19.4s, v29.4s, v10.4s\n" - "ldr q10, [x23, #0x10]\n" - "fmul v29.4s, v17.4s, v2.s[2]\n" - "fmul v2.4s, v17.4s, v2.s[3]\n" - "fmla v18.4s, v9.4s, v29.4s\n" - "movi v9.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" - "fmla v14.4s, v20.4s, v2.4s\n" - "movi v20.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x20]\n" - ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" - ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" - ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" - ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x30]\n" - ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x40]\n" - ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" - ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" - ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" - ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x50]\n" - ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x60]\n" - ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" - ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" - ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" - ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x0]\n" - ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" - ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" - ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" - ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" - "fmul v10.4s, v17.4s, v26.s[0]\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v11.4s, v9.4s, v10.4s\n" - "ldr q9, [x22, #0x10]\n" - "fmul v10.4s, v17.4s, v26.s[1]\n" - "fmla v13.4s, v29.4s, v10.4s\n" - "ldr d29, [x22, #-0x8]\n" - "fmul v10.4s, v17.4s, v26.s[2]\n" - "fmul v26.4s, v17.4s, v26.s[3]\n" - "fcvtl v29.4s, v29.4h\n" - "fmla v23.4s, v20.4s, v10.4s\n" - "movi v20.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v16.4s, v2.4s, v26.4s\n" - "movi v26.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x20]\n" - ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x30]\n" - ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x40]\n" - ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" - ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" - ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" - ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x50]\n" - ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x60]\n" - ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" - ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" - ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" - ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x70]\n" - "add x22, x22, #0x88\n" - ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x21, #0x0]\n" - ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" - ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" - ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" - "fmul v9.4s, v17.4s, v29.s[0]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v25.4s, v20.4s, v9.4s\n" - "ldr q9, [x21, #0x10]\n" - "fmul v20.4s, v17.4s, v29.s[1]\n" - "fmla v7.4s, v10.4s, v20.4s\n" - "ldr d20, [x21, #-0x8]\n" - "fmul v10.4s, v17.4s, v29.s[2]\n" - "fmul v29.4s, v17.4s, v29.s[3]\n" - "fcvtl v20.4s, v20.4h\n" - "fmla v0.4s, v26.4s, v10.4s\n" - "movi v26.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v4.4s, v2.4s, v29.4s\n" - "movi v2.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" - "ldr q12, [x21, #0x20]\n" - "fmul v24.4s, v17.4s, v20.s[0]\n" - ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x21, #0x30]\n" - "fmul v31.4s, v17.4s, v20.s[1]\n" - ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" - ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" - ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" - ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" - "ldr q12, [x21, #0x40]\n" - "fmul v6.4s, v17.4s, v20.s[2]\n" - "fmul v20.4s, v17.4s, v20.s[3]\n" - ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" - ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" - ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" - ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" - "ldr q9, [x21, #0x50]\n" - ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" - ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" - ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" - ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" - "ldr q12, [x21, #0x60]\n" - ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" - ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" - ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" - ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" - "ldr q17, [x21, #0x70]\n" - "add x21, x21, #0x88\n" - ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" - ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" - ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" - ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" - ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" - ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" - ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" - ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "fmla v5.4s, v26.4s, v24.4s\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "fmla v21.4s, v10.4s, v31.4s\n" - "fmla v8.4s, v2.4s, v6.4s\n" - "fmla v1.4s, v29.4s, v20.4s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x27, x27, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "str q15, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q14, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q11, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q16, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q0, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q21, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q8, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q1, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x10, x10, #0x10\n" - "cmp x10, #0x10\n" - "mov %x[res_ptr], x26\n" - "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x10, 9f\n" - "5:" // Row tail: Row loop - "add x24, %x[b_ptr], #0x8\n" - "mov x23, %x[nc]\n" - "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "add x25, %x[a_ptr], #0x8\n" - "mov x21, %x[nb]\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "7:" // Row tail: Block loop - "ldr q7, [x24, #0x0]\n" - "ldr q5, [x25, #0x0]\n" - "movi v9.16b, #0x4\n" - "movi v4.4s, #0x0\n" - "ldr q3, [x24, #0x10]\n" - "ldr q2, [x25, #0x10]\n" - "movi v1.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "ldr q13, [x24, #0x20]\n" - "ldr q31, [x25, #0x20]\n" - "movi v30.4s, #0x0\n" - "movi v29.16b, #0xf0\n" - "ldr q28, [x24, #0x30]\n" - "ldr q27, [x25, #0x30]\n" - "sshl v20.16b, v7.16b, v9.16b\n" - "sub x20, x24, #0x8\n" - "ldr q26, [x25, #0x40]\n" - "ldr q25, [x25, #0x50]\n" - "sshl v17.16b, v3.16b, v9.16b\n" - "and v7.16b, v7.16b, v29.16b\n" - "ldr q24, [x25, #0x60]\n" - "ldr q16, [x25, #0x70]\n" - "sshl v22.16b, v13.16b, v9.16b\n" - "and v3.16b, v3.16b, v29.16b\n" - "ldr d21, [x20, #0x0]\n" - "ldr d12, [x25, #-0x8]\n" - ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" - ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" - ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" - ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" - "sshl v9.16b, v28.16b, v9.16b\n" - "subs x21, x21, #0x1\n" - "and v13.16b, v13.16b, v29.16b\n" - "and v28.16b, v28.16b, v29.16b\n" - "add x25, x25, #0x88\n" - "add x24, x24, #0x48\n" - "fcvtl v21.4s, v21.4h\n" - "fcvtl v12.4s, v12.4h\n" - ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" - ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" - ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" - ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" - "fmul v11.4s, v21.4s, v12.s[0]\n" - "fmul v23.4s, v21.4s, v12.s[1]\n" - "fmul v17.4s, v21.4s, v12.s[2]\n" - ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" - "fmul v6.4s, v21.4s, v12.s[3]\n" - ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" - ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" - ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" - ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" - ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" - ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" - ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" - ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" - ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" - ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" - ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" - ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" - ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" - ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" - ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" - ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" - ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" - ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" - ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" - ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" - ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" - ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" - ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" - "scvtf v4.4s, v4.4s, #0x4\n" - "scvtf v1.4s, v1.4s, #0x4\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "fmla v15.4s, v4.4s, v11.4s\n" - "scvtf v30.4s, v30.4s, #0x4\n" - "fmla v19.4s, v1.4s, v23.4s\n" - "fmla v18.4s, v0.4s, v17.4s\n" - "fmla v14.4s, v30.4s, v6.4s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x10, #0x1\n" - "str q15, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x2\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x3\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "str q14, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "bne 6b\n" - "subs x10, x10, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x9\n" - "mov %x[res_ptr], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); -#else - float sumf[4][4]; - int sumi; +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + if (lm_ggml_cpu_has_neon()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v23.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "movi v1.16b, #0x0\n" + "3:" // Block loop + "ldr q3, [x28, #0x0]\n" + "ldr q31, [x25, #0x0]\n" + "movi v28.16b, #0x4\n" + "movi v10.4s, #0x0\n" + "ldr q22, [x28, #0x10]\n" + "ldr q6, [x25, #0x10]\n" + "movi v29.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q27, [x28, #0x20]\n" + "ldr q30, [x28, #0x30]\n" + "movi v20.4s, #0x0\n" + "movi v24.16b, #0xf0\n" + "ldr d2, [x25, #-0x8]\n" + "ldr d26, [x23, #-0x8]\n" + "sshl v12.16b, v3.16b, v28.16b\n" + "sub x20, x28, #0x8\n" + "ldr d17, [x20, #0x0]\n" + "and v3.16b, v3.16b, v24.16b\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" + ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" + ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" + ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" + "sshl v31.16b, v22.16b, v28.16b\n" + "and v22.16b, v22.16b, v24.16b\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v2.4s, v2.4h\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" + ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" + ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" + ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" + "sshl v6.16b, v27.16b, v28.16b\n" + "sshl v28.16b, v30.16b, v28.16b\n" + "and v27.16b, v27.16b, v24.16b\n" + "and v30.16b, v30.16b, v24.16b\n" + "ldr q24, [x25, #0x20]\n" + ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x30]\n" + ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" + ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" + ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" + ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x40]\n" + ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x50]\n" + ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" + ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" + ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" + ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x60]\n" + ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" + ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" + ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" + ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" + "fmul v24.4s, v17.4s, v2.s[0]\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v15.4s, v10.4s, v24.4s\n" + "ldr q24, [x23, #0x0]\n" + "fmul v10.4s, v17.4s, v2.s[1]\n" + "fmla v19.4s, v29.4s, v10.4s\n" + "ldr q10, [x23, #0x10]\n" + "fmul v29.4s, v17.4s, v2.s[2]\n" + "fmul v2.4s, v17.4s, v2.s[3]\n" + "fmla v18.4s, v9.4s, v29.4s\n" + "movi v9.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" + "fmla v14.4s, v20.4s, v2.4s\n" + "movi v20.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x20]\n" + ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" + ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" + ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" + ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x30]\n" + ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x40]\n" + ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" + ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" + ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" + ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x50]\n" + ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x60]\n" + ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" + ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" + ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" + ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x0]\n" + ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" + ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" + ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" + ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" + "fmul v10.4s, v17.4s, v26.s[0]\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v11.4s, v9.4s, v10.4s\n" + "ldr q9, [x22, #0x10]\n" + "fmul v10.4s, v17.4s, v26.s[1]\n" + "fmla v13.4s, v29.4s, v10.4s\n" + "ldr d29, [x22, #-0x8]\n" + "fmul v10.4s, v17.4s, v26.s[2]\n" + "fmul v26.4s, v17.4s, v26.s[3]\n" + "fcvtl v29.4s, v29.4h\n" + "fmla v23.4s, v20.4s, v10.4s\n" + "movi v20.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v16.4s, v2.4s, v26.4s\n" + "movi v26.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x20]\n" + ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x30]\n" + ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x40]\n" + ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x50]\n" + ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x60]\n" + ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x21, #0x0]\n" + ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" + ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" + ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" + "fmul v9.4s, v17.4s, v29.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v25.4s, v20.4s, v9.4s\n" + "ldr q9, [x21, #0x10]\n" + "fmul v20.4s, v17.4s, v29.s[1]\n" + "fmla v7.4s, v10.4s, v20.4s\n" + "ldr d20, [x21, #-0x8]\n" + "fmul v10.4s, v17.4s, v29.s[2]\n" + "fmul v29.4s, v17.4s, v29.s[3]\n" + "fcvtl v20.4s, v20.4h\n" + "fmla v0.4s, v26.4s, v10.4s\n" + "movi v26.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v4.4s, v2.4s, v29.4s\n" + "movi v2.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" + "ldr q12, [x21, #0x20]\n" + "fmul v24.4s, v17.4s, v20.s[0]\n" + ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x30]\n" + "fmul v31.4s, v17.4s, v20.s[1]\n" + ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" + ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" + ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" + ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x40]\n" + "fmul v6.4s, v17.4s, v20.s[2]\n" + "fmul v20.4s, v17.4s, v20.s[3]\n" + ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x50]\n" + ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" + ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" + ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" + ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x60]\n" + ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" + "ldr q17, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" + ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" + ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" + ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" + ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" + ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" + ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" + ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v5.4s, v26.4s, v24.4s\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v21.4s, v10.4s, v31.4s\n" + "fmla v8.4s, v2.4s, v6.4s\n" + "fmla v1.4s, v29.4s, v20.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q16, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q0, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q1, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q7, [x24, #0x0]\n" + "ldr q5, [x25, #0x0]\n" + "movi v9.16b, #0x4\n" + "movi v4.4s, #0x0\n" + "ldr q3, [x24, #0x10]\n" + "ldr q2, [x25, #0x10]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q13, [x24, #0x20]\n" + "ldr q31, [x25, #0x20]\n" + "movi v30.4s, #0x0\n" + "movi v29.16b, #0xf0\n" + "ldr q28, [x24, #0x30]\n" + "ldr q27, [x25, #0x30]\n" + "sshl v20.16b, v7.16b, v9.16b\n" + "sub x20, x24, #0x8\n" + "ldr q26, [x25, #0x40]\n" + "ldr q25, [x25, #0x50]\n" + "sshl v17.16b, v3.16b, v9.16b\n" + "and v7.16b, v7.16b, v29.16b\n" + "ldr q24, [x25, #0x60]\n" + "ldr q16, [x25, #0x70]\n" + "sshl v22.16b, v13.16b, v9.16b\n" + "and v3.16b, v3.16b, v29.16b\n" + "ldr d21, [x20, #0x0]\n" + "ldr d12, [x25, #-0x8]\n" + ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" + ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" + ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" + ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" + "sshl v9.16b, v28.16b, v9.16b\n" + "subs x21, x21, #0x1\n" + "and v13.16b, v13.16b, v29.16b\n" + "and v28.16b, v28.16b, v29.16b\n" + "add x25, x25, #0x88\n" + "add x24, x24, #0x48\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v12.4s, v12.4h\n" + ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" + ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" + ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" + ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" + "fmul v11.4s, v21.4s, v12.s[0]\n" + "fmul v23.4s, v21.4s, v12.s[1]\n" + "fmul v17.4s, v21.4s, v12.s[2]\n" + ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" + "fmul v6.4s, v21.4s, v12.s[3]\n" + ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" + ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" + ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" + ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" + ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" + ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" + ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" + ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" + ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" + ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" + ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" + ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" + ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" + ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" + ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" + ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" + ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" + ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" + ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" + ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" + ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" + ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" + ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" + "scvtf v4.4s, v4.4s, #0x4\n" + "scvtf v1.4s, v1.4s, #0x4\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "fmla v15.4s, v4.4s, v11.4s\n" + "scvtf v30.4s, v30.4s, #0x4\n" + "fmla v19.4s, v1.4s, v23.4s\n" + "fmla v18.4s, v0.4s, v17.4s\n" + "fmla v14.4s, v30.4s, v6.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q14, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); } - sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); } } } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } } } } -#endif } void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { @@ -1267,413 +1622,406 @@ void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) - if (lm_ggml_sve_cnt_b == QK8_0) { - LM_GGML_ASSERT(!(lm_ggml_cpu_has_sve() && (lm_ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "3:" // Block loop + "ldr q21, [x28, #0x0]\n" + "ldr q16, [x28, #0x10]\n" + "movi v1.16b, #0x4\n" + "movi v19.4s, #0x0\n" + "ldr q27, [x25, #0x0]\n" + "ldr q15, [x25, #0x10]\n" + "movi v26.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "ldr q29, [x28, #0x20]\n" + "ldr q3, [x28, #0x30]\n" + "movi v17.4s, #0x0\n" + "movi v0.16b, #0xf0\n" + "ldr d20, [x25, #-0x8]\n" + "ldr d9, [x23, #-0x8]\n" + "sshl v8.16b, v21.16b, v1.16b\n" + "sshl v31.16b, v16.16b, v1.16b\n" + "and v21.16b, v21.16b, v0.16b\n" + "and v16.16b, v16.16b, v0.16b\n" + "sub x20, x28, #0x8\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" + ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" + "ldr q27, [x25, #0x20]\n" + ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" + ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" + "sshl v15.16b, v29.16b, v1.16b\n" + "sshl v1.16b, v3.16b, v1.16b\n" + "and v29.16b, v29.16b, v0.16b\n" + "and v3.16b, v3.16b, v0.16b\n" + "ldr q0, [x25, #0x30]\n" + "fcvtl v20.4s, v20.4h\n" + ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" + "fcvtl v9.4s, v9.4h\n" + ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" + "ldr q27, [x25, #0x40]\n" + ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" + ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" + "ldr q0, [x25, #0x50]\n" + ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" + ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" + "ldr q27, [x25, #0x60]\n" + ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" + ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" + "ldr q0, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" + ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" + "ldr d27, [x20, #0x0]\n" + ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" + ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" + "fcvtl v27.4s, v27.4h\n" + "uzp1 v0.2d, v19.2d, v26.2d\n" + "uzp2 v26.2d, v19.2d, v26.2d\n" + "fmul v19.4s, v27.4s, v20.s[0]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v2.4s, v0.4s, v19.4s\n" + "ldr q19, [x23, #0x0]\n" + "uzp1 v0.2d, v18.2d, v17.2d\n" + "uzp2 v18.2d, v18.2d, v17.2d\n" + "fmul v17.4s, v27.4s, v20.s[1]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v10.4s, v26.4s, v17.4s\n" + "ldr q17, [x23, #0x10]\n" + "fmul v26.4s, v27.4s, v20.s[2]\n" + "fmul v20.4s, v27.4s, v20.s[3]\n" + "fmla v12.4s, v0.4s, v26.4s\n" + "ldr d0, [x22, #-0x8]\n" + "ldr d26, [x21, #-0x8]\n" + "fcvtl v0.4s, v0.4h\n" + "fmla v28.4s, v18.4s, v20.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x23, #0x20]\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x23, #0x40]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q19, [x23, #0x60]\n" + ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" + ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" + "uzp1 v19.2d, v20.2d, v18.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp2 v20.2d, v20.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v9.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v11.4s, v19.4s, v18.4s\n" + "ldr q18, [x22, #0x0]\n" + "fmul v19.4s, v27.4s, v9.s[1]\n" + "fmla v13.4s, v20.4s, v19.4s\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" + ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" + "ldr q17, [x23, #0x30]\n" + ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" + "ldr q17, [x23, #0x50]\n" + ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" + "ldr q17, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v9.s[2]\n" + "fmul v9.4s, v27.4s, v9.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v22.4s, v17.4s, v19.4s\n" + "ldr q17, [x22, #0x10]\n" + "movi v19.4s, #0x0\n" + ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" + "fmla v23.4s, v20.4s, v9.4s\n" + "movi v20.4s, #0x0\n" + "movi v9.4s, #0x0\n" + ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" + "ldr q18, [x22, #0x20]\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" + ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" + "ldr q18, [x22, #0x40]\n" + ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" + ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" + "ldr q18, [x22, #0x60]\n" + ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" + ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" + "ldr q17, [x22, #0x30]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" + "ldr q17, [x22, #0x50]\n" + ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" + "ldr q17, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v0.s[0]\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v25.4s, v17.4s, v19.4s\n" + "ldr q19, [x21, #0x0]\n" + "fmul v17.4s, v27.4s, v0.s[1]\n" + "fmla v5.4s, v20.4s, v17.4s\n" + "ldr q17, [x21, #0x10]\n" + "uzp1 v20.2d, v9.2d, v18.2d\n" + "uzp2 v9.2d, v9.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v0.s[2]\n" + "fmul v0.4s, v27.4s, v0.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "fmla v7.4s, v20.4s, v18.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x21, #0x20]\n" + "fmla v4.4s, v9.4s, v0.4s\n" + "movi v9.4s, #0x0\n" + "movi v0.4s, #0x0\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + "fmul v8.4s, v27.4s, v26.s[0]\n" + ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" + "ldr q17, [x21, #0x30]\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + "fmul v31.4s, v27.4s, v26.s[1]\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x21, #0x40]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + "fmul v15.4s, v27.4s, v26.s[2]\n" + "fmul v27.4s, v27.4s, v26.s[3]\n" + ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" + "ldr q1, [x21, #0x50]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q26, [x21, #0x60]\n" + ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" + ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" + "ldr q21, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" + ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" + ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" + ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" + "uzp1 v29.2d, v20.2d, v18.2d\n" + "uzp2 v21.2d, v20.2d, v18.2d\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "uzp1 v18.2d, v9.2d, v0.2d\n" + "uzp2 v16.2d, v9.2d, v0.2d\n" + "scvtf v21.4s, v21.4s, #0x4\n" + "fmla v6.4s, v29.4s, v8.4s\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v30.4s, v21.4s, v31.4s\n" + "fmla v24.4s, v18.4s, v15.4s\n" + "fmla v14.4s, v16.4s, v27.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q6, [x24, #0x0]\n" + "ldr q5, [x24, #0x10]\n" + "movi v17.16b, #0x4\n" + "movi v8.4s, #0x0\n" + "ldr q4, [x25, #0x0]\n" + "ldr q13, [x25, #0x10]\n" + "movi v27.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q31, [x24, #0x20]\n" + "ldr q14, [x24, #0x30]\n" + "movi v29.4s, #0x0\n" + "movi v22.16b, #0xf0\n" + "ldr q11, [x25, #0x20]\n" + "ldr q23, [x25, #0x30]\n" + "sshl v21.16b, v6.16b, v17.16b\n" + "sshl v16.16b, v5.16b, v17.16b\n" + "ldr q20, [x25, #0x40]\n" + "ldr q26, [x25, #0x50]\n" + "and v6.16b, v6.16b, v22.16b\n" + "and v5.16b, v5.16b, v22.16b\n" + "ldr q25, [x25, #0x60]\n" + "ldr q3, [x25, #0x70]\n" + "sshl v19.16b, v31.16b, v17.16b\n" + "sshl v18.16b, v14.16b, v17.16b\n" + "ldr d17, [x25, #-0x8]\n" + ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" + ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" + "and v31.16b, v31.16b, v22.16b\n" + ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" + ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" + "and v14.16b, v14.16b, v22.16b\n" + "sub x20, x24, #0x8\n" + "ldr d16, [x20, #0x0]\n" + "subs x21, x21, #0x1\n" + "add x25, x25, #0x88\n" + "fcvtl v17.4s, v17.4h\n" + "add x24, x24, #0x48\n" + ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" + ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" + ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" + ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" + "fcvtl v16.4s, v16.4h\n" + ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" + ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" + "fmul v23.4s, v16.4s, v17.s[0]\n" + "fmul v21.4s, v16.4s, v17.s[1]\n" + "fmul v1.4s, v16.4s, v17.s[2]\n" + "fmul v20.4s, v16.4s, v17.s[3]\n" + ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" + ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" + ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" + ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" + ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" + ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" + "uzp1 v19.2d, v8.2d, v27.2d\n" + "uzp2 v18.2d, v8.2d, v27.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v0.2d, v29.2d\n" + "uzp2 v16.2d, v0.2d, v29.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v2.4s, v19.4s, v23.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v10.4s, v18.4s, v21.4s\n" + "fmla v12.4s, v17.4s, v1.4s\n" + "fmla v28.4s, v16.4s, v20.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q28, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); + return; } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v2.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v12.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v22.16b, #0x0\n" - "movi v23.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v6.16b, #0x0\n" - "movi v30.16b, #0x0\n" - "movi v24.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "3:" // Block loop - "ldr q21, [x28, #0x0]\n" - "ldr q16, [x28, #0x10]\n" - "movi v1.16b, #0x4\n" - "movi v19.4s, #0x0\n" - "ldr q27, [x25, #0x0]\n" - "ldr q15, [x25, #0x10]\n" - "movi v26.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "ldr q29, [x28, #0x20]\n" - "ldr q3, [x28, #0x30]\n" - "movi v17.4s, #0x0\n" - "movi v0.16b, #0xf0\n" - "ldr d20, [x25, #-0x8]\n" - "ldr d9, [x23, #-0x8]\n" - "sshl v8.16b, v21.16b, v1.16b\n" - "sshl v31.16b, v16.16b, v1.16b\n" - "and v21.16b, v21.16b, v0.16b\n" - "and v16.16b, v16.16b, v0.16b\n" - "sub x20, x28, #0x8\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" - ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" - "ldr q27, [x25, #0x20]\n" - ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" - ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" - "sshl v15.16b, v29.16b, v1.16b\n" - "sshl v1.16b, v3.16b, v1.16b\n" - "and v29.16b, v29.16b, v0.16b\n" - "and v3.16b, v3.16b, v0.16b\n" - "ldr q0, [x25, #0x30]\n" - "fcvtl v20.4s, v20.4h\n" - ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" - "fcvtl v9.4s, v9.4h\n" - ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" - "ldr q27, [x25, #0x40]\n" - ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" - ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" - "ldr q0, [x25, #0x50]\n" - ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" - ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" - "ldr q27, [x25, #0x60]\n" - ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" - ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" - "ldr q0, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" - ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" - "ldr d27, [x20, #0x0]\n" - ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" - ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" - "fcvtl v27.4s, v27.4h\n" - "uzp1 v0.2d, v19.2d, v26.2d\n" - "uzp2 v26.2d, v19.2d, v26.2d\n" - "fmul v19.4s, v27.4s, v20.s[0]\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "fmla v2.4s, v0.4s, v19.4s\n" - "ldr q19, [x23, #0x0]\n" - "uzp1 v0.2d, v18.2d, v17.2d\n" - "uzp2 v18.2d, v18.2d, v17.2d\n" - "fmul v17.4s, v27.4s, v20.s[1]\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v10.4s, v26.4s, v17.4s\n" - "ldr q17, [x23, #0x10]\n" - "fmul v26.4s, v27.4s, v20.s[2]\n" - "fmul v20.4s, v27.4s, v20.s[3]\n" - "fmla v12.4s, v0.4s, v26.4s\n" - "ldr d0, [x22, #-0x8]\n" - "ldr d26, [x21, #-0x8]\n" - "fcvtl v0.4s, v0.4h\n" - "fmla v28.4s, v18.4s, v20.4s\n" - "movi v20.4s, #0x0\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" - ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" - "ldr q19, [x23, #0x20]\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" - ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" - "ldr q19, [x23, #0x40]\n" - ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" - ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" - "ldr q19, [x23, #0x60]\n" - ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" - ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" - "uzp1 v19.2d, v20.2d, v18.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp2 v20.2d, v20.2d, v18.2d\n" - "fmul v18.4s, v27.4s, v9.s[0]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v11.4s, v19.4s, v18.4s\n" - "ldr q18, [x22, #0x0]\n" - "fmul v19.4s, v27.4s, v9.s[1]\n" - "fmla v13.4s, v20.4s, v19.4s\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" - ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" - "ldr q17, [x23, #0x30]\n" - ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" - ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" - "ldr q17, [x23, #0x50]\n" - ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" - ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" - "ldr q17, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" - ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" - "uzp1 v17.2d, v19.2d, v20.2d\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "uzp2 v20.2d, v19.2d, v20.2d\n" - "fmul v19.4s, v27.4s, v9.s[2]\n" - "fmul v9.4s, v27.4s, v9.s[3]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v22.4s, v17.4s, v19.4s\n" - "ldr q17, [x22, #0x10]\n" - "movi v19.4s, #0x0\n" - ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" - "fmla v23.4s, v20.4s, v9.4s\n" - "movi v20.4s, #0x0\n" - "movi v9.4s, #0x0\n" - ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" - "ldr q18, [x22, #0x20]\n" - ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" - ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" - ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" - "ldr q18, [x22, #0x40]\n" - ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" - ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" - "ldr q18, [x22, #0x60]\n" - ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" - ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" - "ldr q17, [x22, #0x30]\n" - ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" - ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" - "ldr q17, [x22, #0x50]\n" - ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" - ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" - "ldr q17, [x22, #0x70]\n" - "add x22, x22, #0x88\n" - ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" - ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" - "uzp1 v17.2d, v19.2d, v20.2d\n" - "uzp2 v20.2d, v19.2d, v20.2d\n" - "fmul v19.4s, v27.4s, v0.s[0]\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v25.4s, v17.4s, v19.4s\n" - "ldr q19, [x21, #0x0]\n" - "fmul v17.4s, v27.4s, v0.s[1]\n" - "fmla v5.4s, v20.4s, v17.4s\n" - "ldr q17, [x21, #0x10]\n" - "uzp1 v20.2d, v9.2d, v18.2d\n" - "uzp2 v9.2d, v9.2d, v18.2d\n" - "fmul v18.4s, v27.4s, v0.s[2]\n" - "fmul v0.4s, v27.4s, v0.s[3]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "fmla v7.4s, v20.4s, v18.4s\n" - "movi v20.4s, #0x0\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" - ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" - "ldr q19, [x21, #0x20]\n" - "fmla v4.4s, v9.4s, v0.4s\n" - "movi v9.4s, #0x0\n" - "movi v0.4s, #0x0\n" - ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" - "fmul v8.4s, v27.4s, v26.s[0]\n" - ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" - "ldr q17, [x21, #0x30]\n" - ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" - "fmul v31.4s, v27.4s, v26.s[1]\n" - ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" - "ldr q19, [x21, #0x40]\n" - ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" - "fmul v15.4s, v27.4s, v26.s[2]\n" - "fmul v27.4s, v27.4s, v26.s[3]\n" - ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" - "ldr q1, [x21, #0x50]\n" - ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" - ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" - "ldr q26, [x21, #0x60]\n" - ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" - ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" - "ldr q21, [x21, #0x70]\n" - "add x21, x21, #0x88\n" - ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" - ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" - ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" - ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" - "uzp1 v29.2d, v20.2d, v18.2d\n" - "uzp2 v21.2d, v20.2d, v18.2d\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "uzp1 v18.2d, v9.2d, v0.2d\n" - "uzp2 v16.2d, v9.2d, v0.2d\n" - "scvtf v21.4s, v21.4s, #0x4\n" - "fmla v6.4s, v29.4s, v8.4s\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v30.4s, v21.4s, v31.4s\n" - "fmla v24.4s, v18.4s, v15.4s\n" - "fmla v14.4s, v16.4s, v27.4s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x27, x27, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "str q2, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q28, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q11, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q6, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q24, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q14, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x10, x10, #0x10\n" - "cmp x10, #0x10\n" - "mov %x[res_ptr], x26\n" - "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x10, 9f\n" - "5:" // Row tail: Row loop - "add x24, %x[b_ptr], #0x8\n" - "mov x23, %x[nc]\n" - "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "movi v2.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "add x25, %x[a_ptr], #0x8\n" - "mov x21, %x[nb]\n" - "movi v12.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "7:" // Row tail: Block loop - "ldr q6, [x24, #0x0]\n" - "ldr q5, [x24, #0x10]\n" - "movi v17.16b, #0x4\n" - "movi v8.4s, #0x0\n" - "ldr q4, [x25, #0x0]\n" - "ldr q13, [x25, #0x10]\n" - "movi v27.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "ldr q31, [x24, #0x20]\n" - "ldr q14, [x24, #0x30]\n" - "movi v29.4s, #0x0\n" - "movi v22.16b, #0xf0\n" - "ldr q11, [x25, #0x20]\n" - "ldr q23, [x25, #0x30]\n" - "sshl v21.16b, v6.16b, v17.16b\n" - "sshl v16.16b, v5.16b, v17.16b\n" - "ldr q20, [x25, #0x40]\n" - "ldr q26, [x25, #0x50]\n" - "and v6.16b, v6.16b, v22.16b\n" - "and v5.16b, v5.16b, v22.16b\n" - "ldr q25, [x25, #0x60]\n" - "ldr q3, [x25, #0x70]\n" - "sshl v19.16b, v31.16b, v17.16b\n" - "sshl v18.16b, v14.16b, v17.16b\n" - "ldr d17, [x25, #-0x8]\n" - ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" - ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" - "and v31.16b, v31.16b, v22.16b\n" - ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" - ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" - "and v14.16b, v14.16b, v22.16b\n" - "sub x20, x24, #0x8\n" - "ldr d16, [x20, #0x0]\n" - "subs x21, x21, #0x1\n" - "add x25, x25, #0x88\n" - "fcvtl v17.4s, v17.4h\n" - "add x24, x24, #0x48\n" - ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" - ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" - ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" - ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" - "fcvtl v16.4s, v16.4h\n" - ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" - ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" - "fmul v23.4s, v16.4s, v17.s[0]\n" - "fmul v21.4s, v16.4s, v17.s[1]\n" - "fmul v1.4s, v16.4s, v17.s[2]\n" - "fmul v20.4s, v16.4s, v17.s[3]\n" - ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" - ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" - ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" - ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" - ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" - ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" - "uzp1 v19.2d, v8.2d, v27.2d\n" - "uzp2 v18.2d, v8.2d, v27.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp1 v17.2d, v0.2d, v29.2d\n" - "uzp2 v16.2d, v0.2d, v29.2d\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v2.4s, v19.4s, v23.4s\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v10.4s, v18.4s, v21.4s\n" - "fmla v12.4s, v17.4s, v1.4s\n" - "fmla v28.4s, v16.4s, v20.4s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x10, #0x1\n" - "str q2, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x2\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x3\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "str q28, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "bne 6b\n" - "subs x10, x10, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x9\n" - "mov %x[res_ptr], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); -#elif defined(__ARM_NEON) && defined(__aarch64__) - LM_GGML_ASSERT((lm_ggml_cpu_has_sve() || lm_ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#else +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) float sumf[4][4]; int sumi; @@ -1693,7 +2041,7 @@ void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; } sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); } @@ -1706,7 +2054,6 @@ void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void } } } -#endif } void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { @@ -1729,8 +2076,9 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - if (lm_ggml_sve_cnt_b == QK8_0) { +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0) { const void * b_ptr = vx; const void * a_ptr = vy; float * res_ptr = s; @@ -2140,25 +2488,960 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void ); return; } - else if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { - LM_GGML_ASSERT((lm_ggml_cpu_has_sve() && (lm_ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " - "performance"); +#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) +#elif defined(__AVX2__) || defined(__AVX512F__) + { + const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; + const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy; + int64_t b_nb = n / QK4_0; + int64_t y = 0; + // Mask to mask out nibbles from packed bytes + const __m256i m4b = _mm256_set1_epi8(0x0F); + const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3); + // Lookup table to convert signed nibbles to signed bytes + __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + // Permute mask used for easier vector processing at later stages + __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); + int64_t xstart = 0; + int anr = nr - nr%16; // Used to align nr with boundary of 16 + #ifdef __AVX512F__ + int anc = nc - nc%16; // Used to align nc with boundary of 16 + // Mask to mask out nibbles from packed bytes expanded to 512 bit length + const __m512i m4bexpanded = _mm512_set1_epi8(0x0F); + // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length + __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1); + + // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < anr / 4; y += 4) { + + const block_q8_0x4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + // 4-bit -> 8-bit - Sign is maintained + const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) + const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) + + const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) + const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) + + const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) + const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) + + const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) + const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) + const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) + + const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) + const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) + + const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) + const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) + + const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) + const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) + + // Shuffle pattern two - right side input + + const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) + const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) + + const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) + const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) + + const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) + const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) + + const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) + const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) + + // Scale values - Load the weight scale values of two block_q4_0x8 + const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // Process LHS in pairs of rows + for (int rp = 0; rp < 4; rp++) { + + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); + __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); + __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); + __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); + __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); + __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); + __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); + __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); + __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); + __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); + __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); + __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); + + __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); + __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); + __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); + __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); + __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); + __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); + __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); + __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); + + // Shuffle pattern one - left side input + + const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m512i iacc_mat_00_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_01_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_10_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_11_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_00_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_01_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); + __m512i iacc_mat_10_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_11_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68); + const __m512 row_scale_f32 = LM_GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); + + // Multiply with appropiate scales and accumulate + acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + } + } + + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < nr / 4; y ++) { + + const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); + + // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + // 4-bit -> 8-bit - Sign is maintained + const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) + const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) + + const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) + const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) + + const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) + const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) + + const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) + const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) + const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) + + const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) + const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) + + const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) + const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) + + const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) + const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) + + // Shuffle pattern two - right side input + + const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) + const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) + + const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) + const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) + + const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) + const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) + + const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) + const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) + + + // Scale values - Load the weight scale values of two block_q4_0x8 + const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); + __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); + __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); + __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); + __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); + __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); + __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); + __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); + __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); + __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); + __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); + __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); + + __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); + __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); + __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); + __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); + __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); + __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); + __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); + __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); + + // Shuffle pattern one - left side input + + const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m512i iacc_mat_00_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_01_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_10_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_11_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_00_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_01_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); + __m512i iacc_mat_10_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_11_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68); + const __m512 row_scale_f32 = LM_GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); + + // Multiply with appropiate scales and accumulate + acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + } + + // Store the accumulated values + for (int i = 0; i < 4; i++) { + _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + if (anc != nc) { + xstart = anc/8; + y = 0; + } + #endif // __AVX512F__ + + // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation + + for (; y < anr / 4; y += 4) { + const block_q8_0x4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = xstart; x < nc / 8; x++) { + + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) + const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) + + const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) + const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) + + const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) + const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) + + const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) + const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) + const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) + + const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) + const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) + + const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) + const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) + + const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) + const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) + + // Shuffle pattern two - right side input + + const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) + const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) + + const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) + const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) + + const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) + const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) + + const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) + const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) + + // Scale values - Load the wight scale values of block_q4_0x8 + const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d); + + // Process LHS in groups of four + for (int rp = 0; rp < 4; rp++) { + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); + + // Shuffle pattern one - left side input + const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m256i iacc_mat_00_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_01_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_10_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_11_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_00_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_01_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); + __m256i iacc_mat_10_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_11_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m256 row_scale_f32 = LM_GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); + + // Multiply with appropiate scales and accumulate + acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + } + } + + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + + // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < nr / 4; y ++) { + + const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); + + // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + for (int64_t x = xstart; x < nc / 8; x++) { + + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) + const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) + + const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) + const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) + + const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) + const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) + + const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) + const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) + const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) + + const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) + const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) + + const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) + const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) + + const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) + const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) + + // Shuffle pattern two - right side input + + const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) + const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) + + const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) + const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) + + const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) + const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) + + const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) + const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) + + // Scale values - Load the wight scale values of block_q4_0x8 + const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d); + + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); + + // Shuffle pattern one - left side input + + const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m256i iacc_mat_00_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_01_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_10_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_11_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_00_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_01_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); + __m256i iacc_mat_10_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_11_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m256 row_scale_f32 = LM_GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask); + + // Multiply with appropiate scales and accumulate + acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + } + + // Store the accumulated values + for (int i = 0; i < 4; i++) { + _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + return; } - else if (lm_ggml_cpu_has_neon()) { - LM_GGML_ASSERT(((lm_ggml_cpu_has_sve() && (lm_ggml_sve_cnt_b == QK8_0)) || lm_ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " - "quantization format for optimal performance"); +#elif defined(__riscv_v_intrinsic) + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); + + // vector version needs Zvfhmin extension + const float a_scales[4] = { + LM_GGML_FP16_TO_FP32(a_ptr[l].d[0]), + LM_GGML_FP16_TO_FP32(a_ptr[l].d[1]), + LM_GGML_FP16_TO_FP32(a_ptr[l].d[2]), + LM_GGML_FP16_TO_FP32(a_ptr[l].d[3]) + }; + const float b_scales[8] = { + LM_GGML_FP16_TO_FP32(b_ptr[l].d[0]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[1]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[2]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[3]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[4]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[5]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[6]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + + const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l0; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l0 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); + sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l1; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l1 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4); + sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48]; + const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80]; + const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l2; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l2 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); + sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24]; + const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56]; + const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88]; + const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l3; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l3 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); + sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); + } + } + __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4); + } + } + + return; } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - LM_GGML_ASSERT(lm_ggml_cpu_has_sve() && - "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) - LM_GGML_ASSERT((lm_ggml_cpu_has_sve() || lm_ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#else +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) float sumf[4][8]; int sumi; @@ -2191,5 +3474,4 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void } } } -#endif } diff --git a/cpp/ggml-alloc.c b/cpp/ggml-alloc.c index 3f27d17..6843f5c 100644 --- a/cpp/ggml-alloc.c +++ b/cpp/ggml-alloc.c @@ -14,7 +14,7 @@ //#define LM_GGML_ALLOCATOR_DEBUG -//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__) +//#define AT_PRINTF(...) LM_GGML_LOG_DEBUG(__VA_ARGS__) #define AT_PRINTF(...) @@ -89,7 +89,7 @@ void lm_ggml_tallocr_alloc(struct lm_ggml_tallocr * talloc, struct lm_ggml_tenso size = LM_GGML_PAD(size, talloc->alignment); if (talloc->offset + size > lm_ggml_backend_buffer_get_size(talloc->buffer)) { - fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n", + LM_GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n", __func__, tensor->name, size, lm_ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset); LM_GGML_ABORT("not enough space in the buffer"); } @@ -172,7 +172,7 @@ static size_t lm_ggml_dyn_tallocr_alloc(struct lm_ggml_dyn_tallocr * alloc, size best_fit_block = alloc->n_free_blocks - 1; } else { // this should never happen - fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", + LM_GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", __func__, size, max_avail); LM_GGML_ABORT("not enough space in the buffer"); } @@ -209,16 +209,16 @@ static size_t lm_ggml_dyn_tallocr_alloc(struct lm_ggml_dyn_tallocr * alloc, size } } } - fprintf(stderr, "max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); + LM_GGML_LOG_DEBUG("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); for (int i = 0; i < 1024; i++) { if (alloc->allocated_tensors[i].tensor) { - fprintf(stderr, "%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name, + LM_GGML_LOG_DEBUG("%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name, alloc->allocated_tensors[i].offset, alloc->allocated_tensors[i].offset + lm_ggml_nbytes(alloc->allocated_tensors[i].tensor), lm_ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0); } } - fprintf(stderr, "\n"); + LM_GGML_LOG_DEBUG("\n"); } #endif @@ -294,6 +294,12 @@ static void lm_ggml_dyn_tallocr_reset(struct lm_ggml_dyn_tallocr * alloc) { alloc->free_blocks[0].offset = 0; alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows alloc->max_size = 0; + +#ifdef LM_GGML_ALLOCATOR_DEBUG + for (int i = 0; i < 1024; i++) { + alloc->allocated_tensors[i].tensor = NULL; + } +#endif } static struct lm_ggml_dyn_tallocr * lm_ggml_dyn_tallocr_new(size_t alignment) { @@ -342,7 +348,6 @@ struct tensor_alloc { }; struct leaf_alloc { - int buffer_id; struct tensor_alloc leaf; }; @@ -734,7 +739,6 @@ bool lm_ggml_gallocr_reserve_n(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph * for (int i = 0; i < graph->n_leafs; i++) { struct lm_ggml_tensor * leaf = graph->leafs[i]; struct hash_node * hn = lm_ggml_gallocr_hash_get(galloc, leaf); - galloc->leaf_allocs[i].buffer_id = hn->buffer_id; if (leaf->view_src || leaf->data) { galloc->leaf_allocs[i].leaf.buffer_id = -1; galloc->leaf_allocs[i].leaf.offset = SIZE_MAX; @@ -762,13 +766,13 @@ bool lm_ggml_gallocr_reserve_n(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph * // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views if (new_size > cur_size || galloc->buffers[i] == NULL) { #ifndef NDEBUG - fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, lm_ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + LM_GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, lm_ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif lm_ggml_backend_buffer_free(galloc->buffers[i]); galloc->buffers[i] = lm_ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); if (galloc->buffers[i] == NULL) { - fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, lm_ggml_backend_buft_name(galloc->bufts[i]), new_size); + LM_GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, lm_ggml_backend_buft_name(galloc->bufts[i]), new_size); return false; } lm_ggml_backend_buffer_set_usage(galloc->buffers[i], LM_GGML_BACKEND_BUFFER_USAGE_COMPUTE); @@ -819,14 +823,14 @@ static bool lm_ggml_gallocr_node_needs_realloc(lm_ggml_gallocr_t galloc, struct static bool lm_ggml_gallocr_needs_realloc(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph * graph) { if (galloc->n_nodes != graph->n_nodes) { #ifndef NDEBUG - fprintf(stderr, "%s: graph has different number of nodes\n", __func__); + LM_GGML_LOG_DEBUG("%s: graph has different number of nodes\n", __func__); #endif return true; } if (galloc->n_leafs != graph->n_leafs) { #ifndef NDEBUG - fprintf(stderr, "%s: graph has different number of leafs\n", __func__); + LM_GGML_LOG_DEBUG("%s: graph has different number of leafs\n", __func__); #endif return true; } @@ -837,7 +841,7 @@ static bool lm_ggml_gallocr_needs_realloc(lm_ggml_gallocr_t galloc, struct lm_gg if (!lm_ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) { #ifndef NDEBUG - fprintf(stderr, "%s: node %s is not valid\n", __func__, node->name); + LM_GGML_LOG_DEBUG("%s: node %s is not valid\n", __func__, node->name); #endif return true; } @@ -849,7 +853,7 @@ static bool lm_ggml_gallocr_needs_realloc(lm_ggml_gallocr_t galloc, struct lm_gg } if (!lm_ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) { #ifndef NDEBUG - fprintf(stderr, "%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name); + LM_GGML_LOG_DEBUG("%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name); #endif return true; } @@ -863,14 +867,14 @@ bool lm_ggml_gallocr_alloc_graph(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph if (lm_ggml_gallocr_needs_realloc(galloc, graph)) { if (galloc->n_buffers == 1) { #ifndef NDEBUG - fprintf(stderr, "%s: reallocating buffers automatically\n", __func__); + LM_GGML_LOG_DEBUG("%s: reallocating buffers automatically\n", __func__); #endif if (!lm_ggml_gallocr_reserve(galloc, graph)) { return false; } } else { #ifndef NDEBUG - fprintf(stderr, "%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__); + LM_GGML_LOG_DEBUG("%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__); #endif return false; } @@ -934,7 +938,7 @@ static bool alloc_tensor_range(struct lm_ggml_context * ctx, lm_ggml_backend_buffer_t buffer = lm_ggml_backend_buft_alloc_buffer(buft, size); if (buffer == NULL) { #ifndef NDEBUG - fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, lm_ggml_backend_buft_name(buft), size); + LM_GGML_LOG_DEBUG("%s: failed to allocate %s buffer of size %zu\n", __func__, lm_ggml_backend_buft_name(buft), size); #endif for (size_t i = 0; i < *n_buffers; i++) { lm_ggml_backend_buffer_free((*buffers)[i]); @@ -984,7 +988,7 @@ lm_ggml_backend_buffer_t lm_ggml_backend_alloc_ctx_tensors_from_buft(struct lm_g } if (this_size > max_size) { - fprintf(stderr, "%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n", + LM_GGML_LOG_ERROR("%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n", __func__, t->name, lm_ggml_backend_buft_name(buft), this_size, max_size); @@ -1016,7 +1020,7 @@ lm_ggml_backend_buffer_t lm_ggml_backend_alloc_ctx_tensors_from_buft(struct lm_g if (n_buffers == 0) { #ifndef NDEBUG - fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__); + LM_GGML_LOG_DEBUG("%s: all tensors in the context are already allocated\n", __func__); #endif return NULL; } diff --git a/cpp/ggml-alloc.h b/cpp/ggml-alloc.h index 10905a8..6055571 100644 --- a/cpp/ggml-alloc.h +++ b/cpp/ggml-alloc.h @@ -7,8 +7,8 @@ extern "C" { #endif typedef struct lm_ggml_backend_buffer_type * lm_ggml_backend_buffer_type_t; -typedef struct lm_ggml_backend_buffer * lm_ggml_backend_buffer_t; -typedef struct lm_ggml_backend * lm_ggml_backend_t; +typedef struct lm_ggml_backend_buffer * lm_ggml_backend_buffer_t; +typedef struct lm_ggml_backend * lm_ggml_backend_t; // Tensor allocator struct lm_ggml_tallocr { @@ -24,7 +24,7 @@ LM_GGML_API void lm_ggml_tallocr_alloc(struct lm_ggml_tallocr * t // Graph allocator /* Example usage: - lm_ggml_gallocr_t galloc = lm_ggml_gallocr_new(lm_ggml_bacckend_cpu_buffer_type()); + lm_ggml_gallocr_t galloc = lm_ggml_gallocr_new(lm_ggml_backend_cpu_buffer_type()); // optional: create a worst-case graph and reserve the buffers to avoid reallocations lm_ggml_gallocr_reserve(galloc, build_graph(max_batch)); diff --git a/cpp/ggml-backend-impl.h b/cpp/ggml-backend-impl.h index 31eba9d..afcd3a9 100644 --- a/cpp/ggml-backend-impl.h +++ b/cpp/ggml-backend-impl.h @@ -9,144 +9,207 @@ extern "C" { #endif // - // Backend buffer + // Backend buffer type // - // buffer type - typedef void * lm_ggml_backend_buffer_type_context_t; - struct lm_ggml_backend_buffer_type_i { - const char * (*LM_GGML_CALL get_name) (lm_ggml_backend_buffer_type_t buft); + const char * (*get_name) (lm_ggml_backend_buffer_type_t buft); // allocate a buffer of this type - lm_ggml_backend_buffer_t (*LM_GGML_CALL alloc_buffer) (lm_ggml_backend_buffer_type_t buft, size_t size); + lm_ggml_backend_buffer_t (*alloc_buffer) (lm_ggml_backend_buffer_type_t buft, size_t size); // tensor alignment - size_t (*LM_GGML_CALL get_alignment) (lm_ggml_backend_buffer_type_t buft); - // max buffer size that can be allocated - size_t (*LM_GGML_CALL get_max_size) (lm_ggml_backend_buffer_type_t buft); - // data size needed to allocate the tensor, including padding - size_t (*LM_GGML_CALL get_alloc_size) (lm_ggml_backend_buffer_type_t buft, const struct lm_ggml_tensor * tensor); - // check if tensor data is in host memory - bool (*LM_GGML_CALL is_host) (lm_ggml_backend_buffer_type_t buft); + size_t (*get_alignment) (lm_ggml_backend_buffer_type_t buft); + // (optional) max buffer size that can be allocated (defaults to SIZE_MAX) + size_t (*get_max_size) (lm_ggml_backend_buffer_type_t buft); + // (optional) data size needed to allocate the tensor, including padding (defaults to lm_ggml_nbytes) + size_t (*get_alloc_size)(lm_ggml_backend_buffer_type_t buft, const struct lm_ggml_tensor * tensor); + // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false) + bool (*is_host) (lm_ggml_backend_buffer_type_t buft); }; struct lm_ggml_backend_buffer_type { struct lm_ggml_backend_buffer_type_i iface; - lm_ggml_backend_buffer_type_context_t context; + lm_ggml_backend_dev_t device; + void * context; }; - // buffer - typedef void * lm_ggml_backend_buffer_context_t; + // + // Backend buffer + // struct lm_ggml_backend_buffer_i { - const char * (*LM_GGML_CALL get_name) (lm_ggml_backend_buffer_t buffer); - void (*LM_GGML_CALL free_buffer)(lm_ggml_backend_buffer_t buffer); - void * (*LM_GGML_CALL get_base) (lm_ggml_backend_buffer_t buffer); - void (*LM_GGML_CALL init_tensor)(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); - void (*LM_GGML_CALL set_tensor) (lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*LM_GGML_CALL get_tensor) (lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size); - bool (*LM_GGML_CALL cpy_tensor) (lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst); // dst is in the buffer, src may be in any buffer - void (*LM_GGML_CALL clear) (lm_ggml_backend_buffer_t buffer, uint8_t value); - void (*LM_GGML_CALL reset) (lm_ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras + // (optional) free the buffer + void (*free_buffer) (lm_ggml_backend_buffer_t buffer); + // base address of the buffer + void * (*get_base) (lm_ggml_backend_buffer_t buffer); + // (optional) initialize a tensor in the buffer (eg. add tensor extras) + void (*init_tensor) (lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); + // tensor data access + void (*memset_tensor)(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); + void (*set_tensor) (lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor) (lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size); + // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported) + bool (*cpy_tensor) (lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst); + // clear the entire buffer + void (*clear) (lm_ggml_backend_buffer_t buffer, uint8_t value); + // (optional) reset any internal state due to tensor initialization, such as tensor extras + void (*reset) (lm_ggml_backend_buffer_t buffer); }; struct lm_ggml_backend_buffer { struct lm_ggml_backend_buffer_i iface; lm_ggml_backend_buffer_type_t buft; - lm_ggml_backend_buffer_context_t context; + void * context; size_t size; enum lm_ggml_backend_buffer_usage usage; }; - LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_buffer_init( - lm_ggml_backend_buffer_type_t buft, - struct lm_ggml_backend_buffer_i iface, - lm_ggml_backend_buffer_context_t context, - size_t size); + lm_ggml_backend_buffer_t lm_ggml_backend_buffer_init( + lm_ggml_backend_buffer_type_t buft, + struct lm_ggml_backend_buffer_i iface, + void * context, + size_t size); // do not use directly, use lm_ggml_backend_tensor_copy instead bool lm_ggml_backend_buffer_copy_tensor(const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst); + // multi-buffer // buffer that contains a collection of buffers - LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_multi_buffer_alloc_buffer(lm_ggml_backend_buffer_t * buffers, size_t n_buffers); - LM_GGML_CALL bool lm_ggml_backend_buffer_is_multi_buffer(lm_ggml_backend_buffer_t buffer); - LM_GGML_CALL void lm_ggml_backend_multi_buffer_set_usage(lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage); + lm_ggml_backend_buffer_t lm_ggml_backend_multi_buffer_alloc_buffer(lm_ggml_backend_buffer_t * buffers, size_t n_buffers); + bool lm_ggml_backend_buffer_is_multi_buffer(lm_ggml_backend_buffer_t buffer); + void lm_ggml_backend_multi_buffer_set_usage(lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage); // - // Backend + // Backend (stream) // - typedef void * lm_ggml_backend_context_t; - struct lm_ggml_backend_i { - const char * (*LM_GGML_CALL get_name)(lm_ggml_backend_t backend); - - void (*LM_GGML_CALL free)(lm_ggml_backend_t backend); + const char * (*get_name)(lm_ggml_backend_t backend); - // buffer allocation - lm_ggml_backend_buffer_type_t (*LM_GGML_CALL get_default_buffer_type)(lm_ggml_backend_t backend); + void (*free)(lm_ggml_backend_t backend); // (optional) asynchronous tensor data access - void (*LM_GGML_CALL set_tensor_async)(lm_ggml_backend_t backend, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*LM_GGML_CALL get_tensor_async)(lm_ggml_backend_t backend, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size); - bool (*LM_GGML_CALL cpy_tensor_async)(lm_ggml_backend_t backend_src, lm_ggml_backend_t backend_dst, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst); + void (*set_tensor_async)(lm_ggml_backend_t backend, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async)(lm_ggml_backend_t backend, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*cpy_tensor_async)(lm_ggml_backend_t backend_src, lm_ggml_backend_t backend_dst, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst); - // (optional) complete all pending operations - void (*LM_GGML_CALL synchronize)(lm_ggml_backend_t backend); + // (optional) complete all pending operations (required if the backend supports async operations) + void (*synchronize)(lm_ggml_backend_t backend); - // compute graph with a plan (not used currently) - // create a new plan for a graph - lm_ggml_backend_graph_plan_t (*LM_GGML_CALL graph_plan_create) (lm_ggml_backend_t backend, const struct lm_ggml_cgraph * cgraph); - void (*LM_GGML_CALL graph_plan_free) (lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan); + // (optional) graph plans (not used currently) + // compute graph with a plan + lm_ggml_backend_graph_plan_t (*graph_plan_create) (lm_ggml_backend_t backend, const struct lm_ggml_cgraph * cgraph); + void (*graph_plan_free) (lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan); // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology - void (*LM_GGML_CALL graph_plan_update) (lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan, const struct lm_ggml_cgraph * cgraph); + void (*graph_plan_update) (lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan, const struct lm_ggml_cgraph * cgraph); // compute the graph with the plan - enum lm_ggml_status (*LM_GGML_CALL graph_plan_compute)(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan); + enum lm_ggml_status (*graph_plan_compute)(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan); - // compute graph without a plan (async) - enum lm_ggml_status (*LM_GGML_CALL graph_compute) (lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph); - - // check if the backend can compute an operation - bool (*LM_GGML_CALL supports_op)(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op); - - // check if the backend can use tensors allocated in a buffer type - bool (*LM_GGML_CALL supports_buft)(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft); - - // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer - // these should be expensive operations with large batch sizes that may benefit from running on this backend - // even if the weight has to be copied from the CPU temporarily - bool (*LM_GGML_CALL offload_op)(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op); + // compute graph (always async if supported by the backend) + enum lm_ggml_status (*graph_compute) (lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph); // (optional) event synchronization - // create a new event that can record events on this backend instance - lm_ggml_backend_event_t (*LM_GGML_CALL event_new) (lm_ggml_backend_t backend); - void (*LM_GGML_CALL event_free) (lm_ggml_backend_event_t event); - // record an event on the backend instance that created it - void (*LM_GGML_CALL event_record) (lm_ggml_backend_event_t event); - // wait for an event on on a different backend instance - void (*LM_GGML_CALL event_wait) (lm_ggml_backend_t backend, lm_ggml_backend_event_t event); - // block until an event is recorded - void (*LM_GGML_CALL event_synchronize) (lm_ggml_backend_event_t event); + // record an event on this stream + void (*event_record)(lm_ggml_backend_t backend, lm_ggml_backend_event_t event); + // wait for an event on on a different stream + void (*event_wait) (lm_ggml_backend_t backend, lm_ggml_backend_event_t event); }; struct lm_ggml_backend { lm_ggml_guid_t guid; - struct lm_ggml_backend_i iface; - lm_ggml_backend_context_t context; + lm_ggml_backend_dev_t device; + void * context; }; struct lm_ggml_backend_event { - lm_ggml_backend_t backend; + struct lm_ggml_backend_device * device; + void * context; + }; + + // + // Backend device + // + + // Note: if additional properties are needed, we should add a struct with all of them + // the current functions to obtain the properties can remain, since they are more convenient for often used properties + struct lm_ggml_backend_device_i { + // device name: short identifier for this device, such as "CPU" or "CUDA0" + const char * (*get_name)(lm_ggml_backend_dev_t dev); + + // device description: short informative description of the device, could be the model name + const char * (*get_description)(lm_ggml_backend_dev_t dev); + + // device memory in bytes + void (*get_memory)(lm_ggml_backend_dev_t dev, size_t * free, size_t * total); + + // device type + enum lm_ggml_backend_dev_type (*get_type)(lm_ggml_backend_dev_t dev); + + // device properties + void (*get_props)(lm_ggml_backend_dev_t dev, struct lm_ggml_backend_dev_props * props); + + // backend (stream) initialization + lm_ggml_backend_t (*init_backend)(lm_ggml_backend_dev_t dev, const char * params); + + // preferred buffer type + lm_ggml_backend_buffer_type_t (*get_buffer_type)(lm_ggml_backend_dev_t dev); + + // (optional) host buffer type (in system memory, typically this is a pinned memory buffer for faster transfers between host and device) + lm_ggml_backend_buffer_type_t (*get_host_buffer_type)(lm_ggml_backend_dev_t dev); + + // (optional) buffer from pointer: create a buffer from a host pointer (useful for memory mapped models and importing data from other libraries) + lm_ggml_backend_buffer_t (*buffer_from_host_ptr)(lm_ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size); + + // check if the backend can compute an operation + bool (*supports_op)(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op); + + // check if the backend can use tensors allocated in a buffer type + bool (*supports_buft)(lm_ggml_backend_dev_t dev, lm_ggml_backend_buffer_type_t buft); + + // (optional) check if the backend wants to run an operation, even if the weights are allocated in an incompatible buffer + // these should be expensive operations that may benefit from running on this backend instead of the CPU backend + bool (*offload_op)(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op); + + // (optional) event synchronization + lm_ggml_backend_event_t (*event_new) (lm_ggml_backend_dev_t dev); + void (*event_free) (lm_ggml_backend_dev_t dev, lm_ggml_backend_event_t event); + void (*event_synchronize) (lm_ggml_backend_dev_t dev, lm_ggml_backend_event_t event); + }; + + struct lm_ggml_backend_device { + struct lm_ggml_backend_device_i iface; + lm_ggml_backend_reg_t reg; void * context; }; // - // Backend registry + // Backend (reg) // - typedef lm_ggml_backend_t (*LM_GGML_CALL lm_ggml_backend_init_fn)(const char * params, void * user_data); + struct lm_ggml_backend_reg_i { + const char * (*get_name)(lm_ggml_backend_reg_t reg); + + // enumerate available devices + size_t (*get_device_count)(lm_ggml_backend_reg_t reg); + lm_ggml_backend_dev_t (*get_device)(lm_ggml_backend_reg_t reg, size_t index); + + // (optional) get a pointer to a function in the backend + // backends can add custom functions that are not part of the standard ggml-backend interface + void * (*get_proc_address)(lm_ggml_backend_reg_t reg, const char * name); + }; + + struct lm_ggml_backend_reg { + // int api_version; // TODO: for dynamic loading + struct lm_ggml_backend_reg_i iface; + void * context; + }; + - LM_GGML_CALL void lm_ggml_backend_register(const char * name, lm_ggml_backend_init_fn init_fn, lm_ggml_backend_buffer_type_t default_buffer_type, void * user_data); + // Internal backend registry API + void lm_ggml_backend_register(lm_ggml_backend_reg_t reg); + void lm_ggml_backend_device_register(lm_ggml_backend_dev_t device); + // TODO: backends can be loaded as a dynamic library, in which case it needs to export this function + // typedef lm_ggml_backend_register_t * (*lm_ggml_backend_init)(void); #ifdef __cplusplus } diff --git a/cpp/ggml-backend.c b/cpp/ggml-backend.cpp similarity index 69% rename from cpp/ggml-backend.c rename to cpp/ggml-backend.cpp index 12c540c..7451de3 100644 --- a/cpp/ggml-backend.c +++ b/cpp/ggml-backend.cpp @@ -1,3 +1,13 @@ +// Note: porting this file to C++ is a work in progress + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + #include "ggml-backend-impl.h" #include "ggml-alloc.h" #include "ggml-impl.h" @@ -8,9 +18,14 @@ #include #include #include +#include +#include +#ifdef __APPLE__ +#include +#include +#endif -#define MAX(a, b) ((a) > (b) ? (a) : (b)) // backend buffer type @@ -18,7 +33,12 @@ const char * lm_ggml_backend_buft_name(lm_ggml_backend_buffer_type_t buft) { return buft->iface.get_name(buft); } -LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_buft_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { +lm_ggml_backend_buffer_t lm_ggml_backend_buft_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { + if (size == 0) { + // return a dummy buffer for zero-sized allocations + return lm_ggml_backend_buffer_init(buft, {}, NULL, 0); + } + return buft->iface.alloc_buffer(buft, size); } @@ -34,7 +54,7 @@ size_t lm_ggml_backend_buft_get_max_size(lm_ggml_backend_buffer_type_t buft) { return SIZE_MAX; } -LM_GGML_CALL size_t lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_type_t buft, struct lm_ggml_tensor * tensor) { +size_t lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_type_t buft, struct lm_ggml_tensor * tensor) { // get_alloc_size is optional, defaults to lm_ggml_nbytes if (buft->iface.get_alloc_size) { size_t size = buft->iface.get_alloc_size(buft, tensor); @@ -51,16 +71,18 @@ bool lm_ggml_backend_buft_is_host(lm_ggml_backend_buffer_type_t buft) { return false; } -// backend buffer +lm_ggml_backend_dev_t lm_ggml_backend_buft_get_device(lm_ggml_backend_buffer_type_t buft) { + return buft->device; +} -LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_buffer_init( - lm_ggml_backend_buffer_type_t buft, - struct lm_ggml_backend_buffer_i iface, - lm_ggml_backend_buffer_context_t context, - size_t size) { - lm_ggml_backend_buffer_t buffer = malloc(sizeof(struct lm_ggml_backend_buffer)); +// backend buffer - (*buffer) = (struct lm_ggml_backend_buffer) { +lm_ggml_backend_buffer_t lm_ggml_backend_buffer_init( + lm_ggml_backend_buffer_type_t buft, + struct lm_ggml_backend_buffer_i iface, + void * context, + size_t size) { + lm_ggml_backend_buffer_t buffer = new lm_ggml_backend_buffer { /* .interface = */ iface, /* .buft = */ buft, /* .context = */ context, @@ -72,7 +94,7 @@ LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_buffer_init( } const char * lm_ggml_backend_buffer_name(lm_ggml_backend_buffer_t buffer) { - return buffer->iface.get_name(buffer); + return lm_ggml_backend_buft_name(lm_ggml_backend_buffer_get_type(buffer)); } void lm_ggml_backend_buffer_free(lm_ggml_backend_buffer_t buffer) { @@ -83,7 +105,7 @@ void lm_ggml_backend_buffer_free(lm_ggml_backend_buffer_t buffer) { if (buffer->iface.free_buffer != NULL) { buffer->iface.free_buffer(buffer); } - free(buffer); + delete buffer; } size_t lm_ggml_backend_buffer_get_size(lm_ggml_backend_buffer_t buffer) { @@ -91,6 +113,11 @@ size_t lm_ggml_backend_buffer_get_size(lm_ggml_backend_buffer_t buffer) { } void * lm_ggml_backend_buffer_get_base(lm_ggml_backend_buffer_t buffer) { + // get_base is optional if the buffer is zero-sized + if (buffer->size == 0) { + return NULL; + } + void * base = buffer->iface.get_base(buffer); LM_GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL"); @@ -98,14 +125,23 @@ void * lm_ggml_backend_buffer_get_base(lm_ggml_backend_buffer_t buffer) { return base; } -LM_GGML_CALL void lm_ggml_backend_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) { +void lm_ggml_backend_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) { // init_tensor is optional if (buffer->iface.init_tensor) { buffer->iface.init_tensor(buffer, tensor); } } -size_t lm_ggml_backend_buffer_get_alignment (lm_ggml_backend_buffer_t buffer) { +void lm_ggml_backend_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { + // clear is optional if the buffer is zero-sized + if (buffer->size == 0) { + return; + } + + buffer->iface.clear(buffer, value); +} + +size_t lm_ggml_backend_buffer_get_alignment(lm_ggml_backend_buffer_t buffer) { return lm_ggml_backend_buft_get_alignment(lm_ggml_backend_buffer_get_type(buffer)); } @@ -117,10 +153,6 @@ size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, st return lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_get_type(buffer), tensor); } -void lm_ggml_backend_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { - buffer->iface.clear(buffer, value); -} - bool lm_ggml_backend_buffer_is_host(lm_ggml_backend_buffer_t buffer) { return lm_ggml_backend_buft_is_host(lm_ggml_backend_buffer_get_type(buffer)); } @@ -181,7 +213,7 @@ void lm_ggml_backend_free(lm_ggml_backend_t backend) { } lm_ggml_backend_buffer_type_t lm_ggml_backend_get_default_buffer_type(lm_ggml_backend_t backend) { - return backend->iface.get_default_buffer_type(backend); + return lm_ggml_backend_dev_buffer_type(backend->device); } lm_ggml_backend_buffer_t lm_ggml_backend_alloc_buffer(lm_ggml_backend_t backend, size_t size) { @@ -218,32 +250,47 @@ void lm_ggml_backend_tensor_get_async(lm_ggml_backend_t backend, const struct lm } } -LM_GGML_CALL void lm_ggml_backend_tensor_set(struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +void lm_ggml_backend_tensor_set(struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { lm_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + if (size == 0) { + return; + } + LM_GGML_ASSERT(buf != NULL && "tensor buffer not set"); LM_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); LM_GGML_ASSERT(offset + size <= lm_ggml_nbytes(tensor) && "tensor write out of bounds"); - if (!size) { - return; - } - buf->iface.set_tensor(buf, tensor, data, offset, size); } -LM_GGML_CALL void lm_ggml_backend_tensor_get(const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) { +void lm_ggml_backend_tensor_get(const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) { lm_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + if (size == 0) { + return; + } + LM_GGML_ASSERT(buf != NULL && "tensor buffer not set"); LM_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); LM_GGML_ASSERT(offset + size <= lm_ggml_nbytes(tensor) && "tensor read out of bounds"); - if (!size) { + buf->iface.get_tensor(buf, tensor, data, offset, size); +} + +LM_GGML_API void lm_ggml_backend_tensor_memset(struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + lm_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + if (size == 0) { return; } - buf->iface.get_tensor(buf, tensor, data, offset, size); + LM_GGML_ASSERT(buf != NULL && "tensor buffer not set"); + LM_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + LM_GGML_ASSERT(offset + size <= lm_ggml_nbytes(tensor) && "tensor write out of bounds"); + LM_GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not implemented by backend buffer"); + + buf->iface.memset_tensor(buf, tensor, value, offset, size); } void lm_ggml_backend_synchronize(lm_ggml_backend_t backend) { @@ -283,18 +330,19 @@ enum lm_ggml_status lm_ggml_backend_graph_compute_async(lm_ggml_backend_t backen } bool lm_ggml_backend_supports_op(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op) { - return backend->iface.supports_op(backend, op); + return lm_ggml_backend_dev_supports_op(backend->device, op); } bool lm_ggml_backend_supports_buft(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft) { - return backend->iface.supports_buft(backend, buft); + return lm_ggml_backend_dev_supports_buft(backend->device, buft); } bool lm_ggml_backend_offload_op(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op) { - if (backend->iface.offload_op != NULL) { - return backend->iface.offload_op(backend, op); - } - return false; + return lm_ggml_backend_dev_offload_op(backend->device, op); +} + +lm_ggml_backend_dev_t lm_ggml_backend_get_device(lm_ggml_backend_t backend) { + return backend->device; } // backend copy @@ -327,7 +375,7 @@ void lm_ggml_backend_tensor_copy(struct lm_ggml_tensor * src, struct lm_ggml_ten lm_ggml_backend_tensor_get(src, dst->data, 0, lm_ggml_nbytes(src)); } else if (!lm_ggml_backend_buffer_copy_tensor(src, dst)) { #ifndef NDEBUG - fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, lm_ggml_backend_buffer_name(src->buffer), lm_ggml_backend_buffer_name(dst->buffer)); + LM_GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, lm_ggml_backend_buffer_name(src->buffer), lm_ggml_backend_buffer_name(dst->buffer)); #endif size_t nbytes = lm_ggml_nbytes(src); void * data = malloc(nbytes); @@ -359,30 +407,31 @@ void lm_ggml_backend_tensor_copy_async(lm_ggml_backend_t backend_src, lm_ggml_ba // events -lm_ggml_backend_event_t lm_ggml_backend_event_new(lm_ggml_backend_t backend) { - if (backend->iface.event_new == NULL) { +lm_ggml_backend_event_t lm_ggml_backend_event_new(lm_ggml_backend_dev_t device) { + // null device is allowed for the transition period to the device interface + if (device == NULL || device->iface.event_new == NULL) { return NULL; } - return backend->iface.event_new(backend); + return device->iface.event_new(device); } void lm_ggml_backend_event_free(lm_ggml_backend_event_t event) { if (event == NULL) { return; } - event->backend->iface.event_free(event); + event->device->iface.event_free(event->device, event); } -void lm_ggml_backend_event_record(lm_ggml_backend_event_t event) { - LM_GGML_ASSERT(event->backend->iface.event_record != NULL); +void lm_ggml_backend_event_record(lm_ggml_backend_event_t event, lm_ggml_backend_t backend) { + LM_GGML_ASSERT(backend->iface.event_record != NULL); - event->backend->iface.event_record(event); + backend->iface.event_record(backend, event); } void lm_ggml_backend_event_synchronize(lm_ggml_backend_event_t event) { - LM_GGML_ASSERT(event->backend->iface.event_synchronize != NULL); + LM_GGML_ASSERT(event->device->iface.event_synchronize); - event->backend->iface.event_synchronize(event); + event->device->iface.event_synchronize(event->device, event); } void lm_ggml_backend_event_wait(lm_ggml_backend_t backend, lm_ggml_backend_event_t event) { @@ -391,170 +440,282 @@ void lm_ggml_backend_event_wait(lm_ggml_backend_t backend, lm_ggml_backend_event backend->iface.event_wait(backend, event); } -// backend registry +// Backend device -#define LM_GGML_REG_MAX_BACKENDS 64 +const char * lm_ggml_backend_dev_name(lm_ggml_backend_dev_t device) { + return device->iface.get_name(device); +} -struct lm_ggml_backend_reg { - char name[128]; - lm_ggml_backend_init_fn init_fn; - lm_ggml_backend_buffer_type_t default_buffer_type; - void * user_data; -}; +const char * lm_ggml_backend_dev_description(lm_ggml_backend_dev_t device) { + return device->iface.get_description(device); +} -static struct lm_ggml_backend_reg lm_ggml_backend_registry[LM_GGML_REG_MAX_BACKENDS]; -static size_t lm_ggml_backend_registry_count = 0; +void lm_ggml_backend_dev_memory(lm_ggml_backend_dev_t device, size_t * free, size_t * total) { + device->iface.get_memory(device, free, total); +} -LM_GGML_CALL static lm_ggml_backend_t lm_ggml_backend_reg_cpu_init(const char * params, void * user_data); +enum lm_ggml_backend_dev_type lm_ggml_backend_dev_type(lm_ggml_backend_dev_t device) { + return device->iface.get_type(device); +} -LM_GGML_CALL static void lm_ggml_backend_registry_init(void) { - static bool initialized = false; +void lm_ggml_backend_dev_get_props(lm_ggml_backend_dev_t device, struct lm_ggml_backend_dev_props * props) { + memset(props, 0, sizeof(*props)); + device->iface.get_props(device, props); +} - if (initialized) { - return; +lm_ggml_backend_reg_t lm_ggml_backend_dev_backend_reg(lm_ggml_backend_dev_t device) { + return device->reg; +} + +lm_ggml_backend_t lm_ggml_backend_dev_init(lm_ggml_backend_dev_t device, const char * params) { + return device->iface.init_backend(device, params); +} + +lm_ggml_backend_buffer_type_t lm_ggml_backend_dev_buffer_type(lm_ggml_backend_dev_t device) { + return device->iface.get_buffer_type(device); +} + +lm_ggml_backend_buffer_type_t lm_ggml_backend_dev_host_buffer_type(lm_ggml_backend_dev_t device) { + if (device->iface.get_host_buffer_type == NULL) { + return NULL; + } + + return device->iface.get_host_buffer_type(device); +} + +lm_ggml_backend_buffer_t lm_ggml_backend_dev_buffer_from_host_ptr(lm_ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) { + return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size); +} + +bool lm_ggml_backend_dev_supports_op(lm_ggml_backend_dev_t device, const struct lm_ggml_tensor * op) { + return device->iface.supports_op(device, op); +} + +bool lm_ggml_backend_dev_supports_buft(lm_ggml_backend_dev_t device, lm_ggml_backend_buffer_type_t buft) { + return device->iface.supports_buft(device, buft); +} + +bool lm_ggml_backend_dev_offload_op(lm_ggml_backend_dev_t device, const struct lm_ggml_tensor * op) { + if (device->iface.offload_op != NULL) { + return device->iface.offload_op(device, op); } - initialized = true; + return false; +} + +// Backend (reg) + +const char * lm_ggml_backend_reg_name(lm_ggml_backend_reg_t reg) { + return reg->iface.get_name(reg); +} + +size_t lm_ggml_backend_reg_dev_count(lm_ggml_backend_reg_t reg) { + return reg->iface.get_device_count(reg); +} + +lm_ggml_backend_dev_t lm_ggml_backend_reg_dev_get(lm_ggml_backend_reg_t reg, size_t index) { + return reg->iface.get_device(reg, index); +} + +void * lm_ggml_backend_reg_get_proc_address(lm_ggml_backend_reg_t reg, const char * name) { + if (!reg->iface.get_proc_address) { + return NULL; + } + return reg->iface.get_proc_address(reg, name); +} - lm_ggml_backend_register("CPU", lm_ggml_backend_reg_cpu_init, lm_ggml_backend_cpu_buffer_type(), NULL); +// Backend registry - // add forward decls here to avoid including the backend headers #ifdef LM_GGML_USE_CUDA - extern LM_GGML_CALL void lm_ggml_backend_cuda_reg_devices(void); - lm_ggml_backend_cuda_reg_devices(); +#include "ggml-cuda.h" #endif -#ifdef LM_GGML_USE_SYCL - extern void lm_ggml_backend_sycl_reg_devices(void); - lm_ggml_backend_sycl_reg_devices(); +#ifdef LM_GGML_USE_METAL +#include "ggml-metal.h" #endif -#ifdef LM_GGML_USE_METAL - extern LM_GGML_CALL lm_ggml_backend_t lm_ggml_backend_reg_metal_init(const char * params, void * user_data); - extern LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void); - lm_ggml_backend_register("Metal", lm_ggml_backend_reg_metal_init, lm_ggml_backend_metal_buffer_type(), NULL); +#ifdef LM_GGML_USE_SYCL +#include "ggml-sycl.h" #endif #ifdef LM_GGML_USE_VULKAN - extern LM_GGML_CALL int lm_ggml_backend_vk_reg_devices(void); - lm_ggml_backend_vk_reg_devices(); +#include "ggml-vulkan.h" #endif -#ifdef LM_GGML_USE_KOMPUTE - extern LM_GGML_CALL void lm_ggml_backend_kompute_reg_devices(void); - lm_ggml_backend_kompute_reg_devices(); +#ifdef LM_GGML_USE_BLAS +#include "ggml-blas.h" #endif -#ifdef LM_GGML_USE_CANN - extern LM_GGML_CALL int lm_ggml_backend_cann_reg_devices(void); - lm_ggml_backend_cann_reg_devices(); +#ifdef LM_GGML_USE_RPC +#include "ggml-rpc.h" #endif -} - -LM_GGML_CALL void lm_ggml_backend_register(const char * name, lm_ggml_backend_init_fn init_fn, lm_ggml_backend_buffer_type_t default_buffer_type, void * user_data) { - LM_GGML_ASSERT(lm_ggml_backend_registry_count < LM_GGML_REG_MAX_BACKENDS); - - size_t id = lm_ggml_backend_registry_count; - - lm_ggml_backend_registry[id] = (struct lm_ggml_backend_reg) { - /* .name = */ {0}, - /* .fn = */ init_fn, - /* .default_buffer_type = */ default_buffer_type, - /* .user_data = */ user_data, - }; - snprintf(lm_ggml_backend_registry[id].name, sizeof(lm_ggml_backend_registry[id].name), "%s", name); +#ifndef __AMX_INT8__ +#undef LM_GGML_USE_AMX +#endif -#ifndef NDEBUG - fprintf(stderr, "%s: registered backend %s\n", __func__, name); +#ifdef LM_GGML_USE_AMX +# include "ggml-amx.h" #endif - lm_ggml_backend_registry_count++; -} +#ifdef LM_GGML_USE_CANN +#include "ggml-cann.h" +#endif -size_t lm_ggml_backend_reg_get_count(void) { - lm_ggml_backend_registry_init(); +#ifdef LM_GGML_USE_KOMPUTE +#include "ggml-kompute.h" +#endif - return lm_ggml_backend_registry_count; -} +struct lm_ggml_backend_registry { + std::vector backends; + std::vector devices; -size_t lm_ggml_backend_reg_find_by_name(const char * name) { - lm_ggml_backend_registry_init(); + lm_ggml_backend_registry() { +#ifdef LM_GGML_USE_CUDA + register_backend(lm_ggml_backend_cuda_reg()); +#endif +#ifdef LM_GGML_USE_METAL + register_backend(lm_ggml_backend_metal_reg()); +#endif +#ifdef LM_GGML_USE_SYCL + register_backend(lm_ggml_backend_sycl_reg()); +#endif +#ifdef LM_GGML_USE_VULKAN + register_backend(lm_ggml_backend_vk_reg()); +#endif +#ifdef LM_GGML_USE_CANN + register_backend(lm_ggml_backend_cann_reg()); +#endif +#ifdef LM_GGML_USE_BLAS + register_backend(lm_ggml_backend_blas_reg()); +#endif +#ifdef LM_GGML_USE_RPC + register_backend(lm_ggml_backend_rpc_reg()); +#endif +#ifdef LM_GGML_USE_AMX + register_backend(lm_ggml_backend_amx_reg()); +#endif +#ifdef LM_GGML_USE_KOMPUTE + register_backend(lm_ggml_backend_kompute_reg()); +#endif - for (size_t i = 0; i < lm_ggml_backend_registry_count; i++) { - // TODO: case insensitive in a portable way - if (strcmp(lm_ggml_backend_registry[i].name, name) == 0) { - return i; - } + register_backend(lm_ggml_backend_cpu_reg()); } - // not found - return SIZE_MAX; -} - -// init from backend:params string -lm_ggml_backend_t lm_ggml_backend_reg_init_backend_from_str(const char * backend_str) { - lm_ggml_backend_registry_init(); - - const char * params = strchr(backend_str, ':'); - char backend_name[128]; - if (params == NULL) { - snprintf(backend_name, sizeof(backend_name), "%s", backend_str); - params = ""; - } else { - snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str); - params++; + void register_backend(lm_ggml_backend_reg_t reg) { +#ifndef NDEBUG + LM_GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", + __func__, lm_ggml_backend_reg_name(reg), lm_ggml_backend_reg_dev_count(reg)); +#endif + backends.push_back(reg); + for (size_t i = 0; i < lm_ggml_backend_reg_dev_count(reg); i++) { + register_device(lm_ggml_backend_reg_dev_get(reg, i)); + } } - size_t backend_i = lm_ggml_backend_reg_find_by_name(backend_name); - - if (backend_i == SIZE_MAX) { - fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name); - return NULL; + void register_device(lm_ggml_backend_dev_t device) { +#ifndef NDEBUG + LM_GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, lm_ggml_backend_dev_name(device), lm_ggml_backend_dev_description(device)); +#endif + devices.push_back(device); } +}; - return lm_ggml_backend_reg_init_backend(backend_i, params); +static lm_ggml_backend_registry & get_reg() { + static lm_ggml_backend_registry reg; + return reg; } -const char * lm_ggml_backend_reg_get_name(size_t i) { - lm_ggml_backend_registry_init(); +// Internal API +void lm_ggml_backend_register(lm_ggml_backend_reg_t reg) { + get_reg().register_backend(reg); +} - LM_GGML_ASSERT(i < lm_ggml_backend_registry_count); - return lm_ggml_backend_registry[i].name; +void lm_ggml_backend_device_register(lm_ggml_backend_dev_t device) { + get_reg().register_device(device); } -lm_ggml_backend_t lm_ggml_backend_reg_init_backend(size_t i, const char * params) { - lm_ggml_backend_registry_init(); +// Backend (reg) enumeration +size_t lm_ggml_backend_reg_count() { + return get_reg().backends.size(); +} - LM_GGML_ASSERT(i < lm_ggml_backend_registry_count); - return lm_ggml_backend_registry[i].init_fn(params, lm_ggml_backend_registry[i].user_data); +lm_ggml_backend_reg_t lm_ggml_backend_reg_get(size_t index) { + LM_GGML_ASSERT(index < lm_ggml_backend_reg_count()); + return get_reg().backends[index]; } -lm_ggml_backend_buffer_type_t lm_ggml_backend_reg_get_default_buffer_type(size_t i) { - lm_ggml_backend_registry_init(); +lm_ggml_backend_reg_t lm_ggml_backend_reg_by_name(const char * name) { + for (size_t i = 0; i < lm_ggml_backend_reg_count(); i++) { + lm_ggml_backend_reg_t reg = lm_ggml_backend_reg_get(i); + if (strcmp(lm_ggml_backend_reg_name(reg), name) == 0) { + return reg; + } + } + return NULL; +} - LM_GGML_ASSERT(i < lm_ggml_backend_registry_count); - return lm_ggml_backend_registry[i].default_buffer_type; +// Device enumeration +size_t lm_ggml_backend_dev_count() { + return get_reg().devices.size(); } -lm_ggml_backend_buffer_t lm_ggml_backend_reg_alloc_buffer(size_t i, size_t size) { - lm_ggml_backend_registry_init(); +lm_ggml_backend_dev_t lm_ggml_backend_dev_get(size_t index) { + LM_GGML_ASSERT(index < lm_ggml_backend_dev_count()); + return get_reg().devices[index]; +} - LM_GGML_ASSERT(i < lm_ggml_backend_registry_count); - return lm_ggml_backend_buft_alloc_buffer(lm_ggml_backend_registry[i].default_buffer_type, size); +lm_ggml_backend_dev_t lm_ggml_backend_dev_by_name(const char * name) { + for (size_t i = 0; i < lm_ggml_backend_dev_count(); i++) { + lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i); + if (strcmp(lm_ggml_backend_dev_name(dev), name) == 0) { + return dev; + } + } + return NULL; } -// backend CPU +lm_ggml_backend_dev_t lm_ggml_backend_dev_by_type(enum lm_ggml_backend_dev_type type) { + for (size_t i = 0; i < lm_ggml_backend_dev_count(); i++) { + lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i); + if (lm_ggml_backend_dev_type(dev) == type) { + return dev; + } + } + return NULL; +} -static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment +// Convenience functions +lm_ggml_backend_t lm_ggml_backend_init_by_name(const char * name, const char * params) { + lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_name(name); + if (!dev) { + return NULL; + } + return lm_ggml_backend_dev_init(dev, params); +} -LM_GGML_CALL static const char * lm_ggml_backend_cpu_buffer_name(lm_ggml_backend_buffer_t buffer) { - return "CPU"; +lm_ggml_backend_t lm_ggml_backend_init_by_type(enum lm_ggml_backend_dev_type type, const char * params) { + lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_type(type); + if (!dev) { + return NULL; + } + return lm_ggml_backend_dev_init(dev, params); +} - LM_GGML_UNUSED(buffer); +lm_ggml_backend_t lm_ggml_backend_init_best(void) { + lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_GPU); + if (!dev) { + dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU); + } + if (!dev) { + return NULL; + } + return lm_ggml_backend_dev_init(dev, NULL); } -LM_GGML_CALL static void * lm_ggml_backend_cpu_buffer_get_base(lm_ggml_backend_buffer_t buffer) { +// CPU backend - buffer + +static void * lm_ggml_backend_cpu_buffer_get_base(lm_ggml_backend_buffer_t buffer) { uintptr_t data = (uintptr_t)buffer->context; // align the buffer @@ -565,23 +726,29 @@ LM_GGML_CALL static void * lm_ggml_backend_cpu_buffer_get_base(lm_ggml_backend_b return (void *)data; } -LM_GGML_CALL static void lm_ggml_backend_cpu_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { - free(buffer->context); +static void lm_ggml_backend_cpu_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { + lm_ggml_aligned_free(buffer->context, buffer->size); } -LM_GGML_CALL static void lm_ggml_backend_cpu_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +static void lm_ggml_backend_cpu_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + LM_GGML_UNUSED(buffer); +} + +static void lm_ggml_backend_cpu_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); LM_GGML_UNUSED(buffer); } -LM_GGML_CALL static void lm_ggml_backend_cpu_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) { +static void lm_ggml_backend_cpu_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) { memcpy(data, (const char *)tensor->data + offset, size); LM_GGML_UNUSED(buffer); } -LM_GGML_CALL static bool lm_ggml_backend_cpu_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) { +static bool lm_ggml_backend_cpu_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) { if (lm_ggml_backend_buffer_is_host(src->buffer)) { memcpy(dst->data, src->data, lm_ggml_nbytes(src)); return true; @@ -591,15 +758,15 @@ LM_GGML_CALL static bool lm_ggml_backend_cpu_buffer_cpy_tensor(lm_ggml_backend_b LM_GGML_UNUSED(buffer); } -LM_GGML_CALL static void lm_ggml_backend_cpu_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { +static void lm_ggml_backend_cpu_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { memset(buffer->context, value, buffer->size); } -static struct lm_ggml_backend_buffer_i cpu_backend_buffer_i = { - /* .get_name = */ lm_ggml_backend_cpu_buffer_name, +static const struct lm_ggml_backend_buffer_i lm_ggml_backend_cpu_buffer_i = { /* .free_buffer = */ lm_ggml_backend_cpu_buffer_free_buffer, /* .get_base = */ lm_ggml_backend_cpu_buffer_get_base, /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ lm_ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ lm_ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ lm_ggml_backend_cpu_buffer_get_tensor, /* .cpy_tensor = */ lm_ggml_backend_cpu_buffer_cpy_tensor, @@ -607,12 +774,11 @@ static struct lm_ggml_backend_buffer_i cpu_backend_buffer_i = { /* .reset = */ NULL, }; -// for buffers from ptr, free is not called -static struct lm_ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { - /* .get_name = */ lm_ggml_backend_cpu_buffer_name, +static const struct lm_ggml_backend_buffer_i lm_ggml_backend_cpu_buffer_from_ptr_i = { /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed /* .get_base = */ lm_ggml_backend_cpu_buffer_get_base, /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ lm_ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ lm_ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ lm_ggml_backend_cpu_buffer_get_tensor, /* .cpy_tensor = */ lm_ggml_backend_cpu_buffer_cpy_tensor, @@ -620,38 +786,40 @@ static struct lm_ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { /* .reset = */ NULL, }; -LM_GGML_CALL static const char * lm_ggml_backend_cpu_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { +// CPU backend - buffer type + +static const char * lm_ggml_backend_cpu_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { return "CPU"; LM_GGML_UNUSED(buft); } -LM_GGML_CALL static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { - size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned - void * data = malloc(size); // TODO: use LM_GGML_ALIGNED_MALLOC (move to ggml-impl.h) +static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { + void * data = lm_ggml_aligned_malloc(size); + if (data == NULL) { - fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size); + LM_GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size); return NULL; } - return lm_ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size); + return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_cpu_buffer_i, data, size); } -LM_GGML_CALL static size_t lm_ggml_backend_cpu_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) { +static size_t lm_ggml_backend_cpu_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) { return TENSOR_ALIGNMENT; LM_GGML_UNUSED(buft); } -LM_GGML_CALL static bool lm_ggml_backend_cpu_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) { +static bool lm_ggml_backend_cpu_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) { return true; LM_GGML_UNUSED(buft); } -LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_type(void) { +lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_type(void) { static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type = { - /* .iface = */ { + /* .iface = */ { /* .get_name = */ lm_ggml_backend_cpu_buffer_type_get_name, /* .alloc_buffer = */ lm_ggml_backend_cpu_buffer_type_alloc_buffer, /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, @@ -659,6 +827,30 @@ LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_type(void) /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, }, + /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), + /* .context = */ NULL, + }; + + return &lm_ggml_backend_cpu_buffer_type; +} + +static const char * lm_ggml_backend_cpu_buffer_from_ptr_type_get_name(lm_ggml_backend_buffer_type_t buft) { + return "CPU_Mapped"; + + LM_GGML_UNUSED(buft); +} + +static lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_from_ptr_type(void) { + static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ lm_ggml_backend_cpu_buffer_from_ptr_type_get_name, + /* .alloc_buffer = */ lm_ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes + /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, + }, + /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), /* .context = */ NULL, }; @@ -671,34 +863,26 @@ LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_type(void) #include -LM_GGML_CALL static const char * lm_ggml_backend_cpu_hbm_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { +static const char * lm_ggml_backend_cpu_hbm_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { return "CPU_HBM"; LM_GGML_UNUSED(buft); } -LM_GGML_CALL static const char * lm_ggml_backend_cpu_hbm_buffer_get_name(lm_ggml_backend_buffer_t buf) { - return "CPU_HBM"; - - LM_GGML_UNUSED(buf); -} - -LM_GGML_CALL static void lm_ggml_backend_cpu_hbm_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { +static void lm_ggml_backend_cpu_hbm_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { hbw_free(buffer->context); } -LM_GGML_CALL static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_hbm_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { - //void * ptr = hbw_malloc(size); +static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_hbm_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { void * ptr; int result = hbw_posix_memalign(&ptr, lm_ggml_backend_cpu_buffer_type_get_alignment(buft), size); if (result != 0) { - fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size); + LM_GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size); return NULL; } lm_ggml_backend_buffer_t buffer = lm_ggml_backend_cpu_buffer_from_ptr(ptr, size); buffer->buft = buft; - buffer->iface.get_name = lm_ggml_backend_cpu_hbm_buffer_get_name; buffer->iface.free_buffer = lm_ggml_backend_cpu_hbm_buffer_free_buffer; return buffer; @@ -721,32 +905,43 @@ lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_hbm_buffer_type(void) { } #endif +static lm_ggml_backend_buffer_type_t * lm_ggml_backend_cpu_get_extra_bufts(lm_ggml_backend_dev_t device) { + static lm_ggml_backend_buffer_type_t bufts[] = { +#ifdef LM_GGML_USE_CPU_HBM + lm_ggml_backend_cpu_hbm_buffer_type(), +#endif + NULL + }; + + return bufts; + + LM_GGML_UNUSED(device); +} + +// CPU backend - backend (stream) + struct lm_ggml_backend_cpu_context { - int n_threads; - void * work_data; - size_t work_size; + int n_threads; + lm_ggml_threadpool_t threadpool; + + uint8_t * work_data; + size_t work_size; lm_ggml_abort_callback abort_callback; void * abort_callback_data; }; -LM_GGML_CALL static const char * lm_ggml_backend_cpu_name(lm_ggml_backend_t backend) { +static const char * lm_ggml_backend_cpu_get_name(lm_ggml_backend_t backend) { return "CPU"; LM_GGML_UNUSED(backend); } -LM_GGML_CALL static void lm_ggml_backend_cpu_free(lm_ggml_backend_t backend) { +static void lm_ggml_backend_cpu_free(lm_ggml_backend_t backend) { struct lm_ggml_backend_cpu_context * cpu_ctx = (struct lm_ggml_backend_cpu_context *)backend->context; - free(cpu_ctx->work_data); - free(cpu_ctx); - free(backend); -} - -LM_GGML_CALL static lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_get_default_buffer_type(lm_ggml_backend_t backend) { - return lm_ggml_backend_cpu_buffer_type(); - - LM_GGML_UNUSED(backend); + delete[] cpu_ctx->work_data; + delete cpu_ctx; + delete backend; } struct lm_ggml_backend_plan_cpu { @@ -754,18 +949,18 @@ struct lm_ggml_backend_plan_cpu { struct lm_ggml_cgraph cgraph; }; -LM_GGML_CALL static lm_ggml_backend_graph_plan_t lm_ggml_backend_cpu_graph_plan_create(lm_ggml_backend_t backend, const struct lm_ggml_cgraph * cgraph) { +static lm_ggml_backend_graph_plan_t lm_ggml_backend_cpu_graph_plan_create(lm_ggml_backend_t backend, const struct lm_ggml_cgraph * cgraph) { struct lm_ggml_backend_cpu_context * cpu_ctx = (struct lm_ggml_backend_cpu_context *)backend->context; - struct lm_ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct lm_ggml_backend_plan_cpu)); + struct lm_ggml_backend_plan_cpu * cpu_plan = new lm_ggml_backend_plan_cpu; - cpu_plan->cplan = lm_ggml_graph_plan(cgraph, cpu_ctx->n_threads); + cpu_plan->cplan = lm_ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); cpu_plan->cgraph = *cgraph; // FIXME: deep copy if (cpu_plan->cplan.work_size > 0) { - cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); + cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size]; if (cpu_plan->cplan.work_data == NULL) { - free(cpu_plan); + delete cpu_plan; return NULL; } } @@ -776,16 +971,16 @@ LM_GGML_CALL static lm_ggml_backend_graph_plan_t lm_ggml_backend_cpu_graph_plan_ return cpu_plan; } -LM_GGML_CALL static void lm_ggml_backend_cpu_graph_plan_free(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan) { +static void lm_ggml_backend_cpu_graph_plan_free(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan) { struct lm_ggml_backend_plan_cpu * cpu_plan = (struct lm_ggml_backend_plan_cpu *)plan; - free(cpu_plan->cplan.work_data); - free(cpu_plan); + delete[] cpu_plan->cplan.work_data; + delete cpu_plan; LM_GGML_UNUSED(backend); } -LM_GGML_CALL static enum lm_ggml_status lm_ggml_backend_cpu_graph_plan_compute(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan) { +static enum lm_ggml_status lm_ggml_backend_cpu_graph_plan_compute(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan) { struct lm_ggml_backend_plan_cpu * cpu_plan = (struct lm_ggml_backend_plan_cpu *)plan; return lm_ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); @@ -793,21 +988,21 @@ LM_GGML_CALL static enum lm_ggml_status lm_ggml_backend_cpu_graph_plan_compute(l LM_GGML_UNUSED(backend); } -LM_GGML_CALL static enum lm_ggml_status lm_ggml_backend_cpu_graph_compute(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph) { +static enum lm_ggml_status lm_ggml_backend_cpu_graph_compute(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph) { struct lm_ggml_backend_cpu_context * cpu_ctx = (struct lm_ggml_backend_cpu_context *)backend->context; - struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, cpu_ctx->n_threads); + struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); if (cpu_ctx->work_size < cplan.work_size) { - free(cpu_ctx->work_data); - cpu_ctx->work_data = malloc(cplan.work_size); + delete[] cpu_ctx->work_data; + cpu_ctx->work_data = new uint8_t[cplan.work_size]; if (cpu_ctx->work_data == NULL) { cpu_ctx->work_size = 0; return LM_GGML_STATUS_ALLOC_FAILED; } cpu_ctx->work_size = cplan.work_size; } - cplan.work_data = cpu_ctx->work_data; + cplan.work_data = (uint8_t *)cpu_ctx->work_data; cplan.abort_callback = cpu_ctx->abort_callback; cplan.abort_callback_data = cpu_ctx->abort_callback_data; @@ -815,33 +1010,9 @@ LM_GGML_CALL static enum lm_ggml_status lm_ggml_backend_cpu_graph_compute(lm_ggm return lm_ggml_graph_compute(cgraph, &cplan); } -LM_GGML_CALL static bool lm_ggml_backend_cpu_supports_op(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op) { - switch (op->op) { - case LM_GGML_OP_CPY: - return - op->type != LM_GGML_TYPE_IQ2_XXS && - op->type != LM_GGML_TYPE_IQ2_XS && - op->type != LM_GGML_TYPE_IQ1_S && - op->type != LM_GGML_TYPE_IQ1_M; // missing type_traits.from_float - case LM_GGML_OP_MUL_MAT: - return op->src[1]->type == LM_GGML_TYPE_F32 || op->src[1]->type == lm_ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; - default: - return true; - } - - LM_GGML_UNUSED(backend); -} - -LM_GGML_CALL static bool lm_ggml_backend_cpu_supports_buft(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft) { - return lm_ggml_backend_buft_is_host(buft); - - LM_GGML_UNUSED(backend); -} - -static struct lm_ggml_backend_i cpu_backend_i = { - /* .get_name = */ lm_ggml_backend_cpu_name, +static const struct lm_ggml_backend_i lm_ggml_backend_cpu_i = { + /* .get_name = */ lm_ggml_backend_cpu_get_name, /* .free = */ lm_ggml_backend_cpu_free, - /* .get_default_buffer_type = */ lm_ggml_backend_cpu_get_default_buffer_type, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, @@ -851,14 +1022,8 @@ static struct lm_ggml_backend_i cpu_backend_i = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ lm_ggml_backend_cpu_graph_plan_compute, /* .graph_compute = */ lm_ggml_backend_cpu_graph_compute, - /* .supports_op = */ lm_ggml_backend_cpu_supports_op, - /* .supports_buft = */ lm_ggml_backend_cpu_supports_buft, - /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, /* .event_record = */ NULL, /* .event_wait = */ NULL, - /* .event_synchronize = */ NULL, }; static lm_ggml_guid_t lm_ggml_backend_cpu_guid(void) { @@ -867,32 +1032,34 @@ static lm_ggml_guid_t lm_ggml_backend_cpu_guid(void) { } lm_ggml_backend_t lm_ggml_backend_cpu_init(void) { - struct lm_ggml_backend_cpu_context * ctx = malloc(sizeof(struct lm_ggml_backend_cpu_context)); + struct lm_ggml_backend_cpu_context * ctx = new lm_ggml_backend_cpu_context; if (ctx == NULL) { return NULL; } ctx->n_threads = LM_GGML_DEFAULT_N_THREADS; + ctx->threadpool = NULL; ctx->work_data = NULL; ctx->work_size = 0; ctx->abort_callback = NULL; ctx->abort_callback_data = NULL; - lm_ggml_backend_t cpu_backend = malloc(sizeof(struct lm_ggml_backend)); + lm_ggml_backend_t cpu_backend = new lm_ggml_backend { + /* .guid = */ lm_ggml_backend_cpu_guid(), + /* .interface = */ lm_ggml_backend_cpu_i, + /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), + /* .context = */ ctx, + }; + if (cpu_backend == NULL) { - free(ctx); + delete ctx; return NULL; } - *cpu_backend = (struct lm_ggml_backend) { - /* .guid = */ lm_ggml_backend_cpu_guid(), - /* .interface = */ cpu_backend_i, - /* .context = */ ctx - }; return cpu_backend; } -LM_GGML_CALL bool lm_ggml_backend_is_cpu(lm_ggml_backend_t backend) { +bool lm_ggml_backend_is_cpu(lm_ggml_backend_t backend) { return backend != NULL && lm_ggml_guid_matches(backend->guid, lm_ggml_backend_cpu_guid()); } @@ -903,6 +1070,18 @@ void lm_ggml_backend_cpu_set_n_threads(lm_ggml_backend_t backend_cpu, int n_thre ctx->n_threads = n_threads; } +void lm_ggml_backend_cpu_set_threadpool(lm_ggml_backend_t backend_cpu, lm_ggml_threadpool_t threadpool) { + LM_GGML_ASSERT(lm_ggml_backend_is_cpu(backend_cpu)); + + struct lm_ggml_backend_cpu_context * ctx = (struct lm_ggml_backend_cpu_context *)backend_cpu->context; + + if (ctx->threadpool && ctx->threadpool != threadpool) { + // already had a different threadpool, pause/suspend it before switching + lm_ggml_threadpool_pause(ctx->threadpool); + } + ctx->threadpool = threadpool; +} + void lm_ggml_backend_cpu_set_abort_callback(lm_ggml_backend_t backend_cpu, lm_ggml_abort_callback abort_callback, void * abort_callback_data) { LM_GGML_ASSERT(lm_ggml_backend_is_cpu(backend_cpu)); @@ -911,35 +1090,248 @@ void lm_ggml_backend_cpu_set_abort_callback(lm_ggml_backend_t backend_cpu, lm_gg ctx->abort_callback_data = abort_callback_data; } -LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { +lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { LM_GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned"); - return lm_ggml_backend_buffer_init(lm_ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size); + return lm_ggml_backend_buffer_init(lm_ggml_backend_cpu_buffer_from_ptr_type(), lm_ggml_backend_cpu_buffer_from_ptr_i, ptr, size); +} + +// CPU backend - device + +struct lm_ggml_backend_cpu_device_context { + std::string description = "CPU"; + + lm_ggml_backend_cpu_device_context() { +#ifdef __APPLE__ + size_t len = 0; + if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) { + description.resize(len); + sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT + } +#elif defined(__linux__) + FILE * f = fopen("/proc/cpuinfo", "r"); + if (f) { + char buf[1024]; + while (fgets(buf, sizeof(buf), f)) { + if (strncmp(buf, "model name", 10) == 0) { + char * p = strchr(buf, ':'); + if (p) { + p++; + while (std::isspace(*p)) { + p++; + } + while (std::isspace(p[strlen(p) - 1])) { + p[strlen(p) - 1] = '\0'; + } + description = p; + break; + } + } + } + fclose(f); + } +#elif defined(_WIN32) + HKEY hKey; + if (RegOpenKeyEx(HKEY_LOCAL_MACHINE, + TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"), + 0, + KEY_READ, + &hKey) == ERROR_SUCCESS) { + DWORD cpu_brand_size = 0; + if (RegQueryValueExA(hKey, + TEXT("ProcessorNameString"), + NULL, + NULL, + NULL, + &cpu_brand_size) == ERROR_SUCCESS) { + description.resize(cpu_brand_size); + if (RegQueryValueExA(hKey, + TEXT("ProcessorNameString"), + NULL, + NULL, + (LPBYTE)&description[0], // NOLINT + &cpu_brand_size) == ERROR_SUCCESS) { + if (description.find('\0') != std::string::npos) { + description.resize(description.find('\0')); + } + } + } + RegCloseKey(hKey); + } +#endif + } +}; + +static const char * lm_ggml_backend_cpu_device_get_name(lm_ggml_backend_dev_t dev) { + return "CPU"; + + LM_GGML_UNUSED(dev); } -LM_GGML_CALL static lm_ggml_backend_t lm_ggml_backend_reg_cpu_init(const char * params, void * user_data) { +static const char * lm_ggml_backend_cpu_device_get_description(lm_ggml_backend_dev_t dev) { + struct lm_ggml_backend_cpu_device_context * ctx = (struct lm_ggml_backend_cpu_device_context *)dev->context; + + return ctx->description.c_str(); +} + +static void lm_ggml_backend_cpu_device_get_memory(lm_ggml_backend_dev_t dev, size_t * free, size_t * total) { + // TODO + *free = 0; + *total = 0; + + LM_GGML_UNUSED(dev); +} + +static enum lm_ggml_backend_dev_type lm_ggml_backend_cpu_device_get_type(lm_ggml_backend_dev_t dev) { + return LM_GGML_BACKEND_DEVICE_TYPE_CPU; + + LM_GGML_UNUSED(dev); +} + +static void lm_ggml_backend_cpu_device_get_props(lm_ggml_backend_dev_t dev, struct lm_ggml_backend_dev_props * props) { + props->name = lm_ggml_backend_cpu_device_get_name(dev); + props->description = lm_ggml_backend_cpu_device_get_description(dev); + props->type = lm_ggml_backend_cpu_device_get_type(dev); + lm_ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static lm_ggml_backend_t lm_ggml_backend_cpu_device_init_backend(lm_ggml_backend_dev_t dev, const char * params) { return lm_ggml_backend_cpu_init(); + LM_GGML_UNUSED(dev); LM_GGML_UNUSED(params); - LM_GGML_UNUSED(user_data); } -// multi-buffer buffer +static lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_device_get_buffer_type(lm_ggml_backend_dev_t dev) { + return lm_ggml_backend_cpu_buffer_type(); -struct lm_ggml_backend_multi_buffer_context { - lm_ggml_backend_buffer_t * buffers; - size_t n_buffers; + LM_GGML_UNUSED(dev); +} + +static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_device_buffer_from_host_ptr(lm_ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + return lm_ggml_backend_cpu_buffer_from_ptr(ptr, size); + + LM_GGML_UNUSED(dev); + LM_GGML_UNUSED(max_tensor_size); +} + +static bool lm_ggml_backend_cpu_device_supports_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) { + switch (op->op) { + case LM_GGML_OP_CPY: + return + op->type != LM_GGML_TYPE_IQ2_XXS && + op->type != LM_GGML_TYPE_IQ2_XS && + op->type != LM_GGML_TYPE_IQ1_S && + op->type != LM_GGML_TYPE_IQ1_M; // missing type_traits.from_float + case LM_GGML_OP_MUL_MAT: + return op->src[1]->type == LM_GGML_TYPE_F32 || op->src[1]->type == lm_ggml_get_type_traits(op->src[0]->type)->vec_dot_type; + case LM_GGML_OP_ROPE_BACK: + return op->src[2] == NULL && (op->op_params[2] & 4) == 0; + case LM_GGML_OP_IM2COL_BACK: + return op->src[0]->type == LM_GGML_TYPE_F32 && op->src[1]->type == LM_GGML_TYPE_F32; + case LM_GGML_OP_OUT_PROD: + return (op->src[0]->type == LM_GGML_TYPE_F32 || lm_ggml_is_quantized(op->src[0]->type)) && op->src[1]->type == LM_GGML_TYPE_F32; + default: + return true; + } + + LM_GGML_UNUSED(dev); +} + +static bool lm_ggml_backend_cpu_device_supports_buft(lm_ggml_backend_dev_t dev, lm_ggml_backend_buffer_type_t buft) { + return lm_ggml_backend_buft_is_host(buft); + + LM_GGML_UNUSED(dev); +} + +static const struct lm_ggml_backend_device_i lm_ggml_backend_cpu_device_i = { + /* .get_name = */ lm_ggml_backend_cpu_device_get_name, + /* .get_description = */ lm_ggml_backend_cpu_device_get_description, + /* .get_memory = */ lm_ggml_backend_cpu_device_get_memory, + /* .get_type = */ lm_ggml_backend_cpu_device_get_type, + /* .get_props = */ lm_ggml_backend_cpu_device_get_props, + /* .init_backend = */ lm_ggml_backend_cpu_device_init_backend, + /* .get_buffer_type = */ lm_ggml_backend_cpu_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ lm_ggml_backend_cpu_device_buffer_from_host_ptr, + /* .supports_op = */ lm_ggml_backend_cpu_device_supports_op, + /* .supports_buft = */ lm_ggml_backend_cpu_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, }; -typedef struct lm_ggml_backend_multi_buffer_context * lm_ggml_backend_multi_buffer_context_t; +// CPU backend - backend (reg) -LM_GGML_CALL static const char * lm_ggml_backend_multi_buffer_get_name(lm_ggml_backend_buffer_t buffer) { - lm_ggml_backend_multi_buffer_context_t ctx = (lm_ggml_backend_multi_buffer_context_t) buffer->context; +static const char * lm_ggml_backend_cpu_reg_get_name(lm_ggml_backend_reg_t reg) { + return "CPU"; - return ctx->buffers[0]->iface.get_name(ctx->buffers[0]); + LM_GGML_UNUSED(reg); } -LM_GGML_CALL static void lm_ggml_backend_multi_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { - lm_ggml_backend_multi_buffer_context_t ctx = (lm_ggml_backend_multi_buffer_context_t) buffer->context; +static size_t lm_ggml_backend_cpu_reg_get_device_count(lm_ggml_backend_reg_t reg) { + return 1; + + LM_GGML_UNUSED(reg); +} + +static lm_ggml_backend_dev_t lm_ggml_backend_cpu_reg_get_device(lm_ggml_backend_reg_t reg, size_t index) { + LM_GGML_ASSERT(index == 0); + + static lm_ggml_backend_cpu_device_context ctx; + static lm_ggml_backend_device lm_ggml_backend_cpu_device = { + /* .iface = */ lm_ggml_backend_cpu_device_i, + /* .reg = */ reg, + /* .context = */ &ctx, + }; + + return &lm_ggml_backend_cpu_device; +} + +static void * lm_ggml_backend_cpu_get_proc_address(lm_ggml_backend_reg_t reg, const char * name) { + if (strcmp(name, "lm_ggml_backend_set_n_threads") == 0) { + return (void *)lm_ggml_backend_cpu_set_n_threads; + } + if (strcmp(name, "lm_ggml_backend_dev_get_extra_bufts") == 0) { + return (void *)lm_ggml_backend_cpu_get_extra_bufts; + } + + return NULL; + + LM_GGML_UNUSED(reg); +} + +static const struct lm_ggml_backend_reg_i lm_ggml_backend_cpu_reg_i = { + /* .get_name = */ lm_ggml_backend_cpu_reg_get_name, + /* .get_device_count = */ lm_ggml_backend_cpu_reg_get_device_count, + /* .get_device = */ lm_ggml_backend_cpu_reg_get_device, + /* .get_proc_address = */ lm_ggml_backend_cpu_get_proc_address, +}; + +lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void) { + static struct lm_ggml_backend_reg lm_ggml_backend_cpu_reg = { + /* .iface = */ lm_ggml_backend_cpu_reg_i, + /* .context = */ NULL, + }; + + return &lm_ggml_backend_cpu_reg; +} + +// multi-buffer buffer + +struct lm_ggml_backend_multi_buffer_context { + lm_ggml_backend_buffer_t * buffers; + size_t n_buffers; +}; + +static void lm_ggml_backend_multi_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { + lm_ggml_backend_multi_buffer_context * ctx = (lm_ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { lm_ggml_backend_buffer_free(ctx->buffers[i]); } @@ -948,31 +1340,27 @@ LM_GGML_CALL static void lm_ggml_backend_multi_buffer_free_buffer(lm_ggml_backen free(ctx); } -LM_GGML_CALL static void lm_ggml_backend_multi_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { - lm_ggml_backend_multi_buffer_context_t ctx = (lm_ggml_backend_multi_buffer_context_t) buffer->context; +static void lm_ggml_backend_multi_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { + lm_ggml_backend_multi_buffer_context * ctx = (lm_ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { lm_ggml_backend_buffer_clear(ctx->buffers[i], value); } } -static struct lm_ggml_backend_buffer_i lm_ggml_backend_multi_buffer_context_interface(void) { - static struct lm_ggml_backend_buffer_i multi_backend_buffer_i = { - /* .get_name = */ lm_ggml_backend_multi_buffer_get_name, - /* .free_buffer = */ lm_ggml_backend_multi_buffer_free_buffer, - /* .get_base = */ NULL, - /* .init_tensor = */ NULL, - /* .set_tensor = */ NULL, - /* .get_tensor = */ NULL, - /* .cpy_tensor = */ NULL, - /* .clear = */ lm_ggml_backend_multi_buffer_clear, - /* .reset = */ NULL, - }; - - return multi_backend_buffer_i; -} +static const struct lm_ggml_backend_buffer_i lm_ggml_backend_multi_buffer_i = { + /* .free_buffer = */ lm_ggml_backend_multi_buffer_free_buffer, + /* .get_base = */ NULL, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ NULL, + /* .get_tensor = */ NULL, + /* .cpy_tensor = */ NULL, + /* .clear = */ lm_ggml_backend_multi_buffer_clear, + /* .reset = */ NULL, +}; -LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_multi_buffer_alloc_buffer(lm_ggml_backend_buffer_t * buffers, size_t n_buffers) { - lm_ggml_backend_multi_buffer_context_t ctx = (lm_ggml_backend_multi_buffer_context_t) malloc(sizeof(struct lm_ggml_backend_multi_buffer_context)); +lm_ggml_backend_buffer_t lm_ggml_backend_multi_buffer_alloc_buffer(lm_ggml_backend_buffer_t * buffers, size_t n_buffers) { + lm_ggml_backend_multi_buffer_context * ctx = (lm_ggml_backend_multi_buffer_context *) malloc(sizeof(struct lm_ggml_backend_multi_buffer_context)); ctx->n_buffers = n_buffers; ctx->buffers = (lm_ggml_backend_buffer_t *) malloc(n_buffers * sizeof(lm_ggml_backend_buffer_t)); @@ -984,16 +1372,16 @@ LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_multi_buffer_alloc_buffer( total_size += lm_ggml_backend_buffer_get_size(buffers[i]); } - return lm_ggml_backend_buffer_init(buffers[0]->buft, lm_ggml_backend_multi_buffer_context_interface(), ctx, total_size); + return lm_ggml_backend_buffer_init(buffers[0]->buft, lm_ggml_backend_multi_buffer_i, ctx, total_size); } -LM_GGML_CALL bool lm_ggml_backend_buffer_is_multi_buffer(lm_ggml_backend_buffer_t buffer) { - return buffer->iface.get_name == lm_ggml_backend_multi_buffer_get_name; +bool lm_ggml_backend_buffer_is_multi_buffer(lm_ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == lm_ggml_backend_multi_buffer_free_buffer; } -LM_GGML_CALL void lm_ggml_backend_multi_buffer_set_usage(lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage) { +void lm_ggml_backend_multi_buffer_set_usage(lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage) { LM_GGML_ASSERT(lm_ggml_backend_buffer_is_multi_buffer(buffer)); - lm_ggml_backend_multi_buffer_context_t ctx = (lm_ggml_backend_multi_buffer_context_t) buffer->context; + lm_ggml_backend_multi_buffer_context * ctx = (lm_ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { lm_ggml_backend_buffer_set_usage(ctx->buffers[i], usage); } @@ -1080,7 +1468,7 @@ struct lm_ggml_backend_sched { char * context_buffer; size_t context_buffer_size; - bool debug; + int debug; }; #define hash_id(tensor) lm_ggml_hash_find_or_insert(&sched->hash_set, tensor) @@ -1113,7 +1501,7 @@ static int lm_ggml_backend_sched_backend_from_buffer(lm_ggml_backend_sched_t sch } #ifndef NDEBUG - fprintf(stderr, "%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n", + LM_GGML_LOG_DEBUG("%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n", __func__, lm_ggml_op_desc(tensor), lm_ggml_backend_buffer_name(buffer), tensor->name); #endif @@ -1150,6 +1538,11 @@ static int lm_ggml_backend_sched_backend_id_from_cur(lm_ggml_backend_sched_t sch } } + if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) { + // since the tensor is pre-allocated, it cannot be moved to another backend + LM_GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation"); + } + // graph input if (tensor->flags & LM_GGML_TENSOR_FLAG_INPUT) { cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU) @@ -1163,7 +1556,9 @@ static int lm_ggml_backend_sched_backend_id_from_cur(lm_ggml_backend_sched_t sch if (src == NULL) { continue; } - if (src->buffer != NULL && src->buffer->usage == LM_GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + // skip ROPE since the rope freqs tensor is too small to choose a backend based on it + // not an ideal solution + if (tensor->op != LM_GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == LM_GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = lm_ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op if (src_backend_id == sched->n_backends - 1) { @@ -1197,32 +1592,34 @@ static void lm_ggml_backend_sched_print_assignments(lm_ggml_backend_sched_t sche for (int i = 0; i < graph->n_nodes; i++) { if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) { lm_ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id]; - fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, lm_ggml_backend_name(split_backend), + LM_GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs: ", cur_split, lm_ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs); for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) { - fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, + LM_GGML_LOG_DEBUG("[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(lm_ggml_nbytes(sched->splits[cur_split].inputs[j]))); } - fprintf(stderr, "\n"); + LM_GGML_LOG_DEBUG("\n"); cur_split++; } struct lm_ggml_tensor * node = graph->nodes[i]; if (lm_ggml_is_view_op(node->op)) { continue; } - lm_ggml_backend_t tensor_backend = lm_ggml_backend_sched_get_tensor_backend(sched, node); - fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, lm_ggml_op_name(node->op), node->name, - fmt_size(lm_ggml_nbytes(node)), tensor_backend ? lm_ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node)); - for (int j = 0; j < LM_GGML_MAX_SRC; j++) { - struct lm_ggml_tensor * src = node->src[j]; - if (src == NULL) { - continue; + if (sched->debug > 1) { + lm_ggml_backend_t tensor_backend = lm_ggml_backend_sched_get_tensor_backend(sched, node); + LM_GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, lm_ggml_op_name(node->op), node->name, + fmt_size(lm_ggml_nbytes(node)), tensor_backend ? lm_ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node)); + for (int j = 0; j < LM_GGML_MAX_SRC; j++) { + struct lm_ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + lm_ggml_backend_t src_backend = lm_ggml_backend_sched_get_tensor_backend(sched, src); + LM_GGML_LOG_DEBUG(" %20.20s (%5.5s) [%5.5s %8.8s]", src->name, + fmt_size(lm_ggml_nbytes(src)), src_backend ? lm_ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src)); } - lm_ggml_backend_t src_backend = lm_ggml_backend_sched_get_tensor_backend(sched, src); - fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name, - fmt_size(lm_ggml_nbytes(src)), src_backend ? lm_ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src)); + LM_GGML_LOG_DEBUG("\n"); } - fprintf(stderr, "\n"); } } @@ -1514,11 +1911,11 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str if (src == NULL) { continue; } - // check if a weight is on a different backend + // check if a weight is on a different and incompatible backend // by starting a new split, the memory of the previously offloaded weights can be reused if (src->buffer != NULL && src->buffer->usage == LM_GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = tensor_backend_id(src); - if (src_backend_id != cur_backend_id) { + if (src_backend_id != cur_backend_id && !lm_ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) { need_new_split = true; break; } @@ -1530,7 +1927,6 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str int src_backend_id = sched->hv_tensor_backend_ids[id]; bool supported = lm_ggml_backend_sched_buffer_supported(sched, src, cur_backend_id); if (src_backend_id != cur_backend_id && tensor_id_copy(id, cur_backend_id, 0) == NULL && !supported) { - //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name); need_new_split = true; break; } @@ -1543,7 +1939,8 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str i_split++; if (i_split >= sched->splits_capacity) { sched->splits_capacity *= 2; - sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct lm_ggml_backend_sched_split)); + sched->splits = (lm_ggml_backend_sched_split *) + realloc(sched->splits, sched->splits_capacity * sizeof(struct lm_ggml_backend_sched_split)); LM_GGML_ASSERT(sched->splits != NULL); } split = &sched->splits[i_split]; @@ -1629,11 +2026,11 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str sched->prev_leaf_backend_ids = tmp; } - int graph_size = graph->n_nodes + sched->n_splits*LM_GGML_SCHED_MAX_SPLIT_INPUTS*2; + int graph_size = std::max(graph->n_nodes, graph->n_leafs) + sched->n_splits*LM_GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies; if (sched->graph.size < graph_size) { sched->graph.size = graph_size; - sched->graph.nodes = realloc(sched->graph.nodes, graph_size * sizeof(struct lm_ggml_tensor *)); - sched->graph.leafs = realloc(sched->graph.leafs, graph_size * sizeof(struct lm_ggml_tensor *)); + sched->graph.nodes = (lm_ggml_tensor **) realloc(sched->graph.nodes, graph_size * sizeof(struct lm_ggml_tensor *)); + sched->graph.leafs = (lm_ggml_tensor **) realloc(sched->graph.leafs, graph_size * sizeof(struct lm_ggml_tensor *)); LM_GGML_ASSERT(sched->graph.nodes != NULL); LM_GGML_ASSERT(sched->graph.leafs != NULL); } @@ -1681,6 +2078,7 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str for (int c = 0; c < sched->n_copies; c++) { struct lm_ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c); sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id; + assert(graph_copy->size > graph_copy->n_leafs); graph_copy->leafs[graph_copy->n_leafs++] = input_cpy; } } @@ -1694,6 +2092,7 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str for (int c = 0; c < sched->n_copies; c++) { struct lm_ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c); sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id; + assert(graph_copy->size > graph_copy->n_leafs); graph_copy->leafs[graph_copy->n_leafs++] = input_cpy; } } @@ -1704,6 +2103,7 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str for (int i = 0; i < graph->n_leafs; i++) { struct lm_ggml_tensor * leaf = graph->leafs[i]; sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf); + assert(graph_copy->size > graph_copy->n_leafs); graph_copy->leafs[graph_copy->n_leafs++] = leaf; } } @@ -1732,11 +2132,11 @@ static bool lm_ggml_backend_sched_alloc_splits(lm_ggml_backend_sched_t sched) { // the re-allocation may cause the split inputs to be moved to a different address lm_ggml_backend_sched_synchronize(sched); #ifndef NDEBUG - fprintf(stderr, "%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); + LM_GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); #endif lm_ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); if (!lm_ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { - fprintf(stderr, "%s: failed to allocate graph\n", __func__); + LM_GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__); return false; } } @@ -1829,7 +2229,7 @@ static enum lm_ggml_status lm_ggml_backend_sched_compute_splits(lm_ggml_backend_ // record the event of this copy if (split->n_inputs > 0) { if (sched->events[split_backend_id][sched->cur_copy] != NULL) { - lm_ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]); + lm_ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy], split_backend); } } } @@ -1849,39 +2249,41 @@ lm_ggml_backend_sched_t lm_ggml_backend_sched_new( LM_GGML_ASSERT(n_backends <= LM_GGML_SCHED_MAX_BACKENDS); LM_GGML_ASSERT(lm_ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU - struct lm_ggml_backend_sched * sched = calloc(1, sizeof(struct lm_ggml_backend_sched)); + struct lm_ggml_backend_sched * sched = (lm_ggml_backend_sched *) calloc(1, sizeof(struct lm_ggml_backend_sched)); - sched->debug = getenv("LM_GGML_SCHED_DEBUG") != NULL; + const char * LM_GGML_SCHED_DEBUG = getenv("LM_GGML_SCHED_DEBUG"); + sched->debug = LM_GGML_SCHED_DEBUG ? atoi(LM_GGML_SCHED_DEBUG) : 0; sched->n_backends = n_backends; sched->n_copies = parallel ? LM_GGML_SCHED_MAX_COPIES : 1; // initialize hash table // FIXME: needs to be size*2 to account for leafs (do it in graph_split instead) sched->hash_set = lm_ggml_hash_set_new(graph_size); - sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); - sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct lm_ggml_tensor *)); + sched->hv_tensor_backend_ids = (int *) malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); + sched->hv_tensor_copies = (lm_ggml_tensor **) malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct lm_ggml_tensor *)); const size_t lm_ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph const size_t nodes_size = graph_size + lm_ggml_sched_max_splits*LM_GGML_SCHED_MAX_SPLIT_INPUTS*2; - sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0])); - sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); - sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); - sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); + sched->node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->node_backend_ids[0])); + sched->leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); + sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); + sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); sched->context_buffer_size = lm_ggml_sched_max_splits*LM_GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct lm_ggml_tensor) + lm_ggml_graph_overhead_custom(graph_size, false); - sched->context_buffer = malloc(sched->context_buffer_size); + sched->context_buffer = (char *) malloc(sched->context_buffer_size); const int initial_splits_capacity = 16; - sched->splits = calloc(initial_splits_capacity, sizeof(sched->splits[0])); + sched->splits = (lm_ggml_backend_sched_split *) calloc(initial_splits_capacity, sizeof(sched->splits[0])); sched->splits_capacity = initial_splits_capacity; for (int b = 0; b < n_backends; b++) { sched->backends[b] = backends[b]; sched->bufts[b] = bufts ? bufts[b] : lm_ggml_backend_get_default_buffer_type(backends[b]); LM_GGML_ASSERT(lm_ggml_backend_supports_buft(backends[b], sched->bufts[b])); + if (sched->n_copies > 1) { for (int c = 0; c < sched->n_copies; c++) { - sched->events[b][c] = lm_ggml_backend_event_new(backends[b]); + sched->events[b][c] = lm_ggml_backend_event_new(backends[b]->device); } } } @@ -2117,8 +2519,8 @@ static void graph_copy_init_tensor(struct lm_ggml_hash_set * hash_set, struct lm struct lm_ggml_backend_graph_copy lm_ggml_backend_graph_copy(lm_ggml_backend_t backend, struct lm_ggml_cgraph * graph) { struct lm_ggml_hash_set hash_set = lm_ggml_hash_set_new(graph->visited_hash_set.size); - struct lm_ggml_tensor ** node_copies = calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT - bool * node_init = calloc(hash_set.size, sizeof(node_init[0])); + struct lm_ggml_tensor ** node_copies = (lm_ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT + bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0])); struct lm_ggml_init_params params = { /* .mem_size = */ lm_ggml_tensor_overhead()*hash_set.size + lm_ggml_graph_overhead_custom(graph->size, false), @@ -2130,13 +2532,13 @@ struct lm_ggml_backend_graph_copy lm_ggml_backend_graph_copy(lm_ggml_backend_t b struct lm_ggml_context * ctx_unallocated = lm_ggml_init(params); if (ctx_allocated == NULL || ctx_unallocated == NULL) { - fprintf(stderr, "failed to allocate context for graph copy\n"); + LM_GGML_LOG_ERROR("%s: failed to allocate context for graph copy\n", __func__); lm_ggml_hash_set_free(&hash_set); free(node_copies); free(node_init); lm_ggml_free(ctx_allocated); lm_ggml_free(ctx_unallocated); - return (struct lm_ggml_backend_graph_copy) { + return { /* .buffer = */ NULL, /* .ctx_allocated = */ NULL, /* .ctx_unallocated = */ NULL, @@ -2153,13 +2555,13 @@ struct lm_ggml_backend_graph_copy lm_ggml_backend_graph_copy(lm_ggml_backend_t b // allocate nodes lm_ggml_backend_buffer_t buffer = lm_ggml_backend_alloc_ctx_tensors(ctx_allocated, backend); if (buffer == NULL) { - fprintf(stderr, "failed to allocate buffer for graph copy\n"); + LM_GGML_LOG_ERROR("%s: failed to allocate buffer for graph copy\n", __func__); lm_ggml_hash_set_free(&hash_set); free(node_copies); free(node_init); lm_ggml_free(ctx_allocated); lm_ggml_free(ctx_unallocated); - return (struct lm_ggml_backend_graph_copy) { + return { /* .buffer = */ NULL, /* .ctx_allocated = */ NULL, /* .ctx_unallocated = */ NULL, @@ -2188,7 +2590,7 @@ struct lm_ggml_backend_graph_copy lm_ggml_backend_graph_copy(lm_ggml_backend_t b free(node_copies); free(node_init); - return (struct lm_ggml_backend_graph_copy) { + return { /* .buffer = */ buffer, /* .ctx_allocated = */ ctx_allocated, /* .ctx_unallocated = */ ctx_unallocated, diff --git a/cpp/ggml-backend.h b/cpp/ggml-backend.h index e0177c2..e85bdca 100644 --- a/cpp/ggml-backend.h +++ b/cpp/ggml-backend.h @@ -12,43 +12,52 @@ extern "C" { typedef struct lm_ggml_backend_event * lm_ggml_backend_event_t; typedef struct lm_ggml_backend * lm_ggml_backend_t; typedef void * lm_ggml_backend_graph_plan_t; + typedef struct lm_ggml_backend_reg * lm_ggml_backend_reg_t; + typedef struct lm_ggml_backend_device * lm_ggml_backend_dev_t; + // - // Backend buffer + // Backend buffer type // - // buffer type - LM_GGML_API const char * lm_ggml_backend_buft_name (lm_ggml_backend_buffer_type_t buft); - LM_GGML_API LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_buft_alloc_buffer (lm_ggml_backend_buffer_type_t buft, size_t size); - LM_GGML_API size_t lm_ggml_backend_buft_get_alignment (lm_ggml_backend_buffer_type_t buft); - LM_GGML_API size_t lm_ggml_backend_buft_get_max_size (lm_ggml_backend_buffer_type_t buft); - LM_GGML_API LM_GGML_CALL size_t lm_ggml_backend_buft_get_alloc_size (lm_ggml_backend_buffer_type_t buft, struct lm_ggml_tensor * tensor); - LM_GGML_API bool lm_ggml_backend_buft_is_host (lm_ggml_backend_buffer_type_t buft); + LM_GGML_API const char * lm_ggml_backend_buft_name (lm_ggml_backend_buffer_type_t buft); + LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_buft_alloc_buffer (lm_ggml_backend_buffer_type_t buft, size_t size); + LM_GGML_API size_t lm_ggml_backend_buft_get_alignment (lm_ggml_backend_buffer_type_t buft); + LM_GGML_API size_t lm_ggml_backend_buft_get_max_size (lm_ggml_backend_buffer_type_t buft); + LM_GGML_API size_t lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_type_t buft, struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_backend_buft_is_host (lm_ggml_backend_buffer_type_t buft); + LM_GGML_API lm_ggml_backend_dev_t lm_ggml_backend_buft_get_device (lm_ggml_backend_buffer_type_t buft); + + // + // Backend buffer + // - // buffer enum lm_ggml_backend_buffer_usage { LM_GGML_BACKEND_BUFFER_USAGE_ANY = 0, LM_GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1, LM_GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2, }; - LM_GGML_API const char * lm_ggml_backend_buffer_name (lm_ggml_backend_buffer_t buffer); - LM_GGML_API void lm_ggml_backend_buffer_free (lm_ggml_backend_buffer_t buffer); - LM_GGML_API void * lm_ggml_backend_buffer_get_base (lm_ggml_backend_buffer_t buffer); - LM_GGML_API size_t lm_ggml_backend_buffer_get_size (lm_ggml_backend_buffer_t buffer); - LM_GGML_API LM_GGML_CALL void lm_ggml_backend_buffer_init_tensor (lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); - LM_GGML_API size_t lm_ggml_backend_buffer_get_alignment (lm_ggml_backend_buffer_t buffer); - LM_GGML_API size_t lm_ggml_backend_buffer_get_max_size (lm_ggml_backend_buffer_t buffer); - LM_GGML_API size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); - LM_GGML_API void lm_ggml_backend_buffer_clear (lm_ggml_backend_buffer_t buffer, uint8_t value); - LM_GGML_API bool lm_ggml_backend_buffer_is_host (lm_ggml_backend_buffer_t buffer); - LM_GGML_API void lm_ggml_backend_buffer_set_usage (lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage); - LM_GGML_API enum lm_ggml_backend_buffer_usage lm_ggml_backend_buffer_get_usage (lm_ggml_backend_buffer_t buffer); - LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_buffer_get_type (lm_ggml_backend_buffer_t buffer); - LM_GGML_API void lm_ggml_backend_buffer_reset (lm_ggml_backend_buffer_t buffer); + LM_GGML_API const char * lm_ggml_backend_buffer_name (lm_ggml_backend_buffer_t buffer); + LM_GGML_API void lm_ggml_backend_buffer_free (lm_ggml_backend_buffer_t buffer); + LM_GGML_API void * lm_ggml_backend_buffer_get_base (lm_ggml_backend_buffer_t buffer); + LM_GGML_API size_t lm_ggml_backend_buffer_get_size (lm_ggml_backend_buffer_t buffer); + LM_GGML_API void lm_ggml_backend_buffer_init_tensor (lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); + LM_GGML_API size_t lm_ggml_backend_buffer_get_alignment (lm_ggml_backend_buffer_t buffer); + LM_GGML_API size_t lm_ggml_backend_buffer_get_max_size (lm_ggml_backend_buffer_t buffer); + LM_GGML_API size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); + LM_GGML_API void lm_ggml_backend_buffer_clear (lm_ggml_backend_buffer_t buffer, uint8_t value); + LM_GGML_API bool lm_ggml_backend_buffer_is_host (lm_ggml_backend_buffer_t buffer); + LM_GGML_API void lm_ggml_backend_buffer_set_usage (lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage); + LM_GGML_API enum lm_ggml_backend_buffer_usage lm_ggml_backend_buffer_get_usage (lm_ggml_backend_buffer_t buffer); + LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_buffer_get_type (lm_ggml_backend_buffer_t buffer); + LM_GGML_API void lm_ggml_backend_buffer_reset (lm_ggml_backend_buffer_t buffer); + + // tensor copy between different backends + LM_GGML_API void lm_ggml_backend_tensor_copy(struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst); // - // Backend + // Backend (stream) // LM_GGML_API lm_ggml_guid_t lm_ggml_backend_guid(lm_ggml_backend_t backend); @@ -63,8 +72,10 @@ extern "C" { LM_GGML_API void lm_ggml_backend_tensor_set_async(lm_ggml_backend_t backend, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size); LM_GGML_API void lm_ggml_backend_tensor_get_async(lm_ggml_backend_t backend, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size); - LM_GGML_API LM_GGML_CALL void lm_ggml_backend_tensor_set( struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size); - LM_GGML_API LM_GGML_CALL void lm_ggml_backend_tensor_get(const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size); + // "offset" refers to the offset of the tensor data for setting/getting data + LM_GGML_API void lm_ggml_backend_tensor_set( struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size); + LM_GGML_API void lm_ggml_backend_tensor_get(const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size); + LM_GGML_API void lm_ggml_backend_tensor_memset( struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); LM_GGML_API void lm_ggml_backend_synchronize(lm_ggml_backend_t backend); @@ -74,64 +85,126 @@ extern "C" { LM_GGML_API enum lm_ggml_status lm_ggml_backend_graph_plan_compute (lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan); LM_GGML_API enum lm_ggml_status lm_ggml_backend_graph_compute (lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph); LM_GGML_API enum lm_ggml_status lm_ggml_backend_graph_compute_async(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph); + + // NOTE: will be removed, use device version instead LM_GGML_API bool lm_ggml_backend_supports_op(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op); LM_GGML_API bool lm_ggml_backend_supports_buft(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft); LM_GGML_API bool lm_ggml_backend_offload_op(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op); - // tensor copy between different backends - LM_GGML_API void lm_ggml_backend_tensor_copy(struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst); - // asynchronous copy // the copy is performed after all the currently queued operations in backend_src // backend_dst will wait for the copy to complete before performing other operations // automatic fallback to sync copy if async is not supported LM_GGML_API void lm_ggml_backend_tensor_copy_async(lm_ggml_backend_t backend_src, lm_ggml_backend_t backend_dst, struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst); - // events - LM_GGML_API lm_ggml_backend_event_t lm_ggml_backend_event_new (lm_ggml_backend_t backend); - LM_GGML_API void lm_ggml_backend_event_free (lm_ggml_backend_event_t event); - LM_GGML_API void lm_ggml_backend_event_record (lm_ggml_backend_event_t event); - LM_GGML_API void lm_ggml_backend_event_synchronize(lm_ggml_backend_event_t event); - LM_GGML_API void lm_ggml_backend_event_wait (lm_ggml_backend_t backend, lm_ggml_backend_event_t event); + LM_GGML_API lm_ggml_backend_dev_t lm_ggml_backend_get_device(lm_ggml_backend_t backend); // - // CPU backend + // Events // - LM_GGML_API lm_ggml_backend_t lm_ggml_backend_cpu_init(void); + LM_GGML_API lm_ggml_backend_event_t lm_ggml_backend_event_new(lm_ggml_backend_dev_t device); + LM_GGML_API void lm_ggml_backend_event_free(lm_ggml_backend_event_t event); + LM_GGML_API void lm_ggml_backend_event_record(lm_ggml_backend_event_t event, lm_ggml_backend_t backend); + LM_GGML_API void lm_ggml_backend_event_synchronize(lm_ggml_backend_event_t event); + LM_GGML_API void lm_ggml_backend_event_wait(lm_ggml_backend_t backend, lm_ggml_backend_event_t event); - LM_GGML_API LM_GGML_CALL bool lm_ggml_backend_is_cpu (lm_ggml_backend_t backend); - LM_GGML_API void lm_ggml_backend_cpu_set_n_threads (lm_ggml_backend_t backend_cpu, int n_threads); - LM_GGML_API void lm_ggml_backend_cpu_set_abort_callback(lm_ggml_backend_t backend_cpu, lm_ggml_abort_callback abort_callback, void * abort_callback_data); + // + // Backend device + // - // Create a backend buffer from an existing pointer - LM_GGML_API LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); + enum lm_ggml_backend_dev_type { + // CPU device using system memory + LM_GGML_BACKEND_DEVICE_TYPE_CPU, + // GPU device using dedicated memory + LM_GGML_BACKEND_DEVICE_TYPE_GPU, + // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) + LM_GGML_BACKEND_DEVICE_TYPE_ACCEL + }; - LM_GGML_API LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_type(void); + // functionality supported by the device + struct lm_ggml_backend_dev_caps { + // asynchronous operations + bool async; + // pinned host buffer + bool host_buffer; + // creating buffers from host ptr + bool buffer_from_host_ptr; + // event synchronization + bool events; + }; -#ifdef LM_GGML_USE_CPU_HBM - LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_hbm_buffer_type(void); -#endif + // all the device properties + struct lm_ggml_backend_dev_props { + const char * name; + const char * description; + size_t memory_free; + size_t memory_total; + enum lm_ggml_backend_dev_type type; + struct lm_ggml_backend_dev_caps caps; + }; + + LM_GGML_API const char * lm_ggml_backend_dev_name(lm_ggml_backend_dev_t device); + LM_GGML_API const char * lm_ggml_backend_dev_description(lm_ggml_backend_dev_t device); + LM_GGML_API void lm_ggml_backend_dev_memory(lm_ggml_backend_dev_t device, size_t * free, size_t * total); + LM_GGML_API enum lm_ggml_backend_dev_type lm_ggml_backend_dev_type(lm_ggml_backend_dev_t device); + LM_GGML_API void lm_ggml_backend_dev_get_props(lm_ggml_backend_dev_t device, struct lm_ggml_backend_dev_props * props); + LM_GGML_API lm_ggml_backend_reg_t lm_ggml_backend_dev_backend_reg(lm_ggml_backend_dev_t device); + LM_GGML_API lm_ggml_backend_t lm_ggml_backend_dev_init(lm_ggml_backend_dev_t device, const char * params); + LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_dev_buffer_type(lm_ggml_backend_dev_t device); + LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_dev_host_buffer_type(lm_ggml_backend_dev_t device); + LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_dev_buffer_from_host_ptr(lm_ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size); + + LM_GGML_API bool lm_ggml_backend_dev_supports_op(lm_ggml_backend_dev_t device, const struct lm_ggml_tensor * op); + LM_GGML_API bool lm_ggml_backend_dev_supports_buft(lm_ggml_backend_dev_t device, lm_ggml_backend_buffer_type_t buft); + LM_GGML_API bool lm_ggml_backend_dev_offload_op(lm_ggml_backend_dev_t device, const struct lm_ggml_tensor * op); // - // Backend registry + // Backend (reg) // - // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way + LM_GGML_API const char * lm_ggml_backend_reg_name(lm_ggml_backend_reg_t reg); + LM_GGML_API size_t lm_ggml_backend_reg_dev_count(lm_ggml_backend_reg_t reg); + LM_GGML_API lm_ggml_backend_dev_t lm_ggml_backend_reg_dev_get(lm_ggml_backend_reg_t reg, size_t index); + LM_GGML_API void * lm_ggml_backend_reg_get_proc_address(lm_ggml_backend_reg_t reg, const char * name); - LM_GGML_API size_t lm_ggml_backend_reg_get_count(void); - LM_GGML_API size_t lm_ggml_backend_reg_find_by_name(const char * name); - LM_GGML_API lm_ggml_backend_t lm_ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional) - LM_GGML_API const char * lm_ggml_backend_reg_get_name(size_t i); - LM_GGML_API lm_ggml_backend_t lm_ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific - LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_reg_get_default_buffer_type(size_t i); - LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_reg_alloc_buffer(size_t i, size_t size); + // Common functions that may be obtained using lm_ggml_backend_reg_get_proc_address + + // Split buffer type for tensor parallelism + typedef lm_ggml_backend_buffer_type_t (*lm_ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split); + // Set the number of threads for the backend + typedef void (*lm_ggml_backend_set_n_threads_t)(lm_ggml_backend_t backend, int n_threads); + // Get additional buffer types provided by the device (returns a NULL-terminated array) + typedef lm_ggml_backend_buffer_type_t * (*lm_ggml_backend_dev_get_extra_bufts_t)(lm_ggml_backend_dev_t device); + + // + // Backend registry + // + + // Backend (reg) enumeration + LM_GGML_API size_t lm_ggml_backend_reg_count(void); + LM_GGML_API lm_ggml_backend_reg_t lm_ggml_backend_reg_get(size_t index); + LM_GGML_API lm_ggml_backend_reg_t lm_ggml_backend_reg_by_name(const char * name); + + // Device enumeration + LM_GGML_API size_t lm_ggml_backend_dev_count(void); + LM_GGML_API lm_ggml_backend_dev_t lm_ggml_backend_dev_get(size_t index); + LM_GGML_API lm_ggml_backend_dev_t lm_ggml_backend_dev_by_name(const char * name); + LM_GGML_API lm_ggml_backend_dev_t lm_ggml_backend_dev_by_type(enum lm_ggml_backend_dev_type type); + + // Direct backend (stream) initialization + // = lm_ggml_backend_dev_init(lm_ggml_backend_dev_by_name(name), params) + LM_GGML_API lm_ggml_backend_t lm_ggml_backend_init_by_name(const char * name, const char * params); + // = lm_ggml_backend_dev_init(lm_ggml_backend_dev_by_type(type), params) + LM_GGML_API lm_ggml_backend_t lm_ggml_backend_init_by_type(enum lm_ggml_backend_dev_type type, const char * params); + // = lm_ggml_backend_dev_init(lm_ggml_backend_dev_by_type(GPU) OR lm_ggml_backend_dev_by_type(CPU), NULL) + LM_GGML_API lm_ggml_backend_t lm_ggml_backend_init_best(void); // // Backend scheduler // - // The backend scheduler allows for multiple backends to be used together + // The backend scheduler allows for multiple backend devices to be used together // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends // The backends are selected based on: // - the backend that supports the operation @@ -166,9 +239,9 @@ extern "C" { } */ - struct lm_ggml_backend_sched; typedef struct lm_ggml_backend_sched * lm_ggml_backend_sched_t; + // Evaluation callback for each node in the graph (set with lm_ggml_backend_sched_set_eval_callback) // when ask == true, the scheduler wants to know if the user wants to observe this node // this allows the scheduler to batch nodes together in order to evaluate them in a single call // @@ -182,7 +255,7 @@ extern "C" { LM_GGML_API void lm_ggml_backend_sched_free(lm_ggml_backend_sched_t sched); // Initialize backend buffers from a measure graph - LM_GGML_API bool lm_ggml_backend_sched_reserve(lm_ggml_backend_sched_t sched, struct lm_ggml_cgraph * measure_graph); + LM_GGML_API bool lm_ggml_backend_sched_reserve(lm_ggml_backend_sched_t sched, struct lm_ggml_cgraph * measure_graph); // returns success LM_GGML_API int lm_ggml_backend_sched_get_n_backends(lm_ggml_backend_sched_t sched); LM_GGML_API lm_ggml_backend_t lm_ggml_backend_sched_get_backend(lm_ggml_backend_sched_t sched, int i); @@ -197,7 +270,7 @@ extern "C" { LM_GGML_API lm_ggml_backend_t lm_ggml_backend_sched_get_tensor_backend(lm_ggml_backend_sched_t sched, struct lm_ggml_tensor * node); // Allocate and compute graph on the backend scheduler - LM_GGML_API bool lm_ggml_backend_sched_alloc_graph(lm_ggml_backend_sched_t sched, struct lm_ggml_cgraph * graph); + LM_GGML_API bool lm_ggml_backend_sched_alloc_graph(lm_ggml_backend_sched_t sched, struct lm_ggml_cgraph * graph); // returns success LM_GGML_API enum lm_ggml_status lm_ggml_backend_sched_graph_compute(lm_ggml_backend_sched_t sched, struct lm_ggml_cgraph * graph); LM_GGML_API enum lm_ggml_status lm_ggml_backend_sched_graph_compute_async(lm_ggml_backend_sched_t sched, struct lm_ggml_cgraph * graph); LM_GGML_API void lm_ggml_backend_sched_synchronize(lm_ggml_backend_sched_t sched); @@ -223,7 +296,7 @@ extern "C" { LM_GGML_API struct lm_ggml_backend_graph_copy lm_ggml_backend_graph_copy(lm_ggml_backend_t backend, struct lm_ggml_cgraph * graph); LM_GGML_API void lm_ggml_backend_graph_copy_free(struct lm_ggml_backend_graph_copy copy); - typedef bool (*LM_GGML_CALL lm_ggml_backend_eval_callback)(int node_index, struct lm_ggml_tensor * t1, struct lm_ggml_tensor * t2, void * user_data); + typedef bool (*lm_ggml_backend_eval_callback)(int node_index, struct lm_ggml_tensor * t1, struct lm_ggml_tensor * t2, void * user_data); // Compare the output of two backends LM_GGML_API bool lm_ggml_backend_compare_graph_backend(lm_ggml_backend_t backend1, lm_ggml_backend_t backend2, struct lm_ggml_cgraph * graph, lm_ggml_backend_eval_callback callback, void * user_data); @@ -232,6 +305,26 @@ extern "C" { LM_GGML_API void lm_ggml_backend_tensor_alloc(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, void * addr); LM_GGML_API void lm_ggml_backend_view_init(struct lm_ggml_tensor * tensor); + // + // CPU backend + // + + LM_GGML_API lm_ggml_backend_t lm_ggml_backend_cpu_init(void); + + LM_GGML_API bool lm_ggml_backend_is_cpu (lm_ggml_backend_t backend); + LM_GGML_API void lm_ggml_backend_cpu_set_n_threads (lm_ggml_backend_t backend_cpu, int n_threads); + LM_GGML_API void lm_ggml_backend_cpu_set_threadpool (lm_ggml_backend_t backend_cpu, lm_ggml_threadpool_t threadpool); + LM_GGML_API void lm_ggml_backend_cpu_set_abort_callback(lm_ggml_backend_t backend_cpu, lm_ggml_abort_callback abort_callback, void * abort_callback_data); + + // Create a backend buffer from an existing pointer + LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); + LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_type(void); + + LM_GGML_API lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void); + +#ifdef LM_GGML_USE_CPU_HBM + LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_hbm_buffer_type(void); +#endif #ifdef __cplusplus } diff --git a/cpp/ggml-common.h b/cpp/ggml-common.h index 617e25a..fc17c61 100644 --- a/cpp/ggml-common.h +++ b/cpp/ggml-common.h @@ -227,6 +227,25 @@ typedef struct { } block_q8_0x8; static_assert(sizeof(block_q8_0x8) == 8 * sizeof(lm_ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); +// +// Ternary quantization +// + +// 1.6875 bpw +typedef struct { + uint8_t qs[(QK_K - 4 * QK_K / 64) / 5]; // 5 elements per byte (3^5 = 243 < 256) + uint8_t qh[QK_K/64]; // 4 elements per byte + lm_ggml_half d; +} block_tq1_0; +static_assert(sizeof(block_tq1_0) == sizeof(lm_ggml_half) + QK_K / 64 + (QK_K - 4 * QK_K / 64) / 5, "wrong tq1_0 block size/padding"); + +// 2.0625 bpw +typedef struct { + uint8_t qs[QK_K/4]; // 2 bits per element + lm_ggml_half d; +} block_tq2_0; +static_assert(sizeof(block_tq2_0) == sizeof(lm_ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding"); + // // Super-block quantization structures // @@ -361,6 +380,7 @@ typedef struct { } block_iq3_s; static_assert(sizeof(block_iq3_s) == sizeof(lm_ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding"); +// 1.5625 bpw typedef struct { lm_ggml_half d; uint8_t qs[QK_K/8]; diff --git a/cpp/ggml-cpp.h b/cpp/ggml-cpp.h new file mode 100644 index 0000000..7e7c8f5 --- /dev/null +++ b/cpp/ggml-cpp.h @@ -0,0 +1,38 @@ +#pragma once + +#ifndef __cplusplus +#error "This header is for C++ only" +#endif + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include + +// Smart pointers for ggml types + +// ggml + +struct lm_ggml_context_deleter { void operator()(lm_ggml_context * ctx) { lm_ggml_free(ctx); } }; +struct lm_gguf_context_deleter { void operator()(lm_gguf_context * ctx) { lm_gguf_free(ctx); } }; + +typedef std::unique_ptr lm_ggml_context_ptr; +typedef std::unique_ptr lm_gguf_context_ptr; + +// ggml-alloc + +struct lm_ggml_gallocr_deleter { void operator()(lm_ggml_gallocr_t galloc) { lm_ggml_gallocr_free(galloc); } }; + +typedef std::unique_ptr lm_ggml_gallocr_ptr; + +// ggml-backend + +struct lm_ggml_backend_deleter { void operator()(lm_ggml_backend_t backend) { lm_ggml_backend_free(backend); } }; +struct lm_ggml_backend_buffer_deleter { void operator()(lm_ggml_backend_buffer_t buffer) { lm_ggml_backend_buffer_free(buffer); } }; +struct lm_ggml_backend_event_deleter { void operator()(lm_ggml_backend_event_t event) { lm_ggml_backend_event_free(event); } }; +struct lm_ggml_backend_sched_deleter { void operator()(lm_ggml_backend_sched_t sched) { lm_ggml_backend_sched_free(sched); } }; + +typedef std::unique_ptr lm_ggml_backend_ptr; +typedef std::unique_ptr lm_ggml_backend_buffer_ptr; +typedef std::unique_ptr lm_ggml_backend_event_ptr; +typedef std::unique_ptr lm_ggml_backend_sched_ptr; diff --git a/cpp/ggml-cpu-impl.h b/cpp/ggml-cpu-impl.h new file mode 100644 index 0000000..760deb2 --- /dev/null +++ b/cpp/ggml-cpu-impl.h @@ -0,0 +1,614 @@ +#pragma once + +// GGML CPU internal header + +#include "ggml.h" +#include "ggml-impl.h" +#include // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/ +//#include +#include +#include // memcpy +#include // fabsf + + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(_MSC_VER) + +#define m512bh(p) p +#define m512i(p) p + +#else + +#define m512bh(p) (__m512bh)(p) +#define m512i(p) (__m512i)(p) + +#endif + +/** + * Converts brain16 to float32. + * + * The bfloat16 floating point format has the following structure: + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌──┴───┐┌─┴───┐ + * 0b0000000000000000 brain16 + * + * Since bf16 has the same number of exponent bits as a 32bit float, + * encoding and decoding numbers becomes relatively straightforward. + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌──┴───┐┌─┴───────────────────┐ + * 0b00000000000000000000000000000000 IEEE binary32 + * + * For comparison, the standard fp16 format has fewer exponent bits. + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌─┴─┐┌─┴──────┐ + * 0b0000000000000000 IEEE binary16 + * + * @see IEEE 754-2008 + */ +static inline float lm_ggml_compute_bf16_to_fp32(lm_ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + +/** + * Converts float32 to brain16. + * + * This is binary identical with Google Brain float conversion. + * Floats shall round to nearest even, and NANs shall be quiet. + * Subnormals aren't flushed to zero, except perhaps when used. + * This code should vectorize nicely if using modern compilers. + */ +static inline lm_ggml_bf16_t lm_ggml_compute_fp32_to_bf16(float s) { + lm_ggml_bf16_t h; + union { + float f; + uint32_t i; + } u; + u.f = s; + if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */ + h.bits = (u.i >> 16) | 64; /* force to quiet */ + return h; + } + h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16; + return h; +} + +#define LM_GGML_FP32_TO_BF16(x) lm_ggml_compute_fp32_to_bf16(x) +#define LM_GGML_BF16_TO_FP32(x) lm_ggml_compute_bf16_to_fp32(x) + +// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 +#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) +#ifndef __FMA__ +#define __FMA__ +#endif +#ifndef __F16C__ +#define __F16C__ +#endif +#endif + +// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available +#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)) +#ifndef __SSE3__ +#define __SSE3__ +#endif +#ifndef __SSSE3__ +#define __SSSE3__ +#endif +#endif + +#if defined(__ARM_FEATURE_SVE) +#include +#include +#endif + +// 16-bit float +// on Arm, we use __fp16 +// on x86, we use uint16_t +#if defined(__ARM_NEON) + +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include + +#ifdef _MSC_VER + +typedef uint16_t lm_ggml_fp16_internal_t; + +#define lm_ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } + +#else + +typedef __fp16 lm_ggml_fp16_internal_t; + +#define lm_ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } + +#endif // _MSC_VER + +#if !defined(__aarch64__) + +// 32-bit ARM compatibility + +// vaddlvq_s16 +// vpaddq_s16 +// vpaddq_s32 +// vaddvq_s32 +// vaddvq_f32 +// vmaxvq_f32 +// vcvtnq_s32_f32 +// vzip1_u8 +// vzip2_u8 + +inline static int32_t vaddlvq_s16(int16x8_t v) { + int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v))); + return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { + int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); + int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); + return vcombine_s32(a0, b0); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +inline static float vmaxvq_f32(float32x4_t v) { + return + MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), + MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { + int32x4_t res; + + res[0] = roundf(vgetq_lane_f32(v, 0)); + res[1] = roundf(vgetq_lane_f32(v, 1)); + res[2] = roundf(vgetq_lane_f32(v, 2)); + res[3] = roundf(vgetq_lane_f32(v, 3)); + + return res; +} + +inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[0]; res[1] = b[0]; + res[2] = a[1]; res[3] = b[1]; + res[4] = a[2]; res[5] = b[2]; + res[6] = a[3]; res[7] = b[3]; + + return res; +} + +inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[4]; res[1] = b[4]; + res[2] = a[5]; res[3] = b[5]; + res[4] = a[6]; res[5] = b[6]; + res[6] = a[7]; res[7] = b[7]; + + return res; +} + +// vld1q_s16_x2 +// vld1q_u8_x2 +// vld1q_u8_x4 +// vld1q_s8_x2 +// vld1q_s8_x4 +// TODO: double-check these work correctly + +typedef struct lm_ggml_int16x8x2_t { + int16x8_t val[2]; +} lm_ggml_int16x8x2_t; + +inline static lm_ggml_int16x8x2_t lm_ggml_vld1q_s16_x2(const int16_t * ptr) { + lm_ggml_int16x8x2_t res; + + res.val[0] = vld1q_s16(ptr + 0); + res.val[1] = vld1q_s16(ptr + 8); + + return res; +} + +typedef struct lm_ggml_uint8x16x2_t { + uint8x16_t val[2]; +} lm_ggml_uint8x16x2_t; + +inline static lm_ggml_uint8x16x2_t lm_ggml_vld1q_u8_x2(const uint8_t * ptr) { + lm_ggml_uint8x16x2_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + + return res; +} + +typedef struct lm_ggml_uint8x16x4_t { + uint8x16_t val[4]; +} lm_ggml_uint8x16x4_t; + +inline static lm_ggml_uint8x16x4_t lm_ggml_vld1q_u8_x4(const uint8_t * ptr) { + lm_ggml_uint8x16x4_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + res.val[2] = vld1q_u8(ptr + 32); + res.val[3] = vld1q_u8(ptr + 48); + + return res; +} + +typedef struct lm_ggml_int8x16x2_t { + int8x16_t val[2]; +} lm_ggml_int8x16x2_t; + +inline static lm_ggml_int8x16x2_t lm_ggml_vld1q_s8_x2(const int8_t * ptr) { + lm_ggml_int8x16x2_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + + return res; +} + +typedef struct lm_ggml_int8x16x4_t { + int8x16_t val[4]; +} lm_ggml_int8x16x4_t; + +inline static lm_ggml_int8x16x4_t lm_ggml_vld1q_s8_x4(const int8_t * ptr) { + lm_ggml_int8x16x4_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + res.val[2] = vld1q_s8(ptr + 32); + res.val[3] = vld1q_s8(ptr + 48); + + return res; +} + +// NOTE: not tested +inline static int8x16_t lm_ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { + int8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +// NOTE: not tested +inline static uint8x16_t lm_ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { + uint8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +#else + +#define lm_ggml_int16x8x2_t int16x8x2_t +#define lm_ggml_uint8x16x2_t uint8x16x2_t +#define lm_ggml_uint8x16x4_t uint8x16x4_t +#define lm_ggml_int8x16x2_t int8x16x2_t +#define lm_ggml_int8x16x4_t int8x16x4_t + +#define lm_ggml_vld1q_s16_x2 vld1q_s16_x2 +#define lm_ggml_vld1q_u8_x2 vld1q_u8_x2 +#define lm_ggml_vld1q_u8_x4 vld1q_u8_x4 +#define lm_ggml_vld1q_s8_x2 vld1q_s8_x2 +#define lm_ggml_vld1q_s8_x4 vld1q_s8_x4 +#define lm_ggml_vqtbl1q_s8 vqtbl1q_s8 +#define lm_ggml_vqtbl1q_u8 vqtbl1q_u8 + +#endif // !defined(__aarch64__) + +#if !defined(__ARM_FEATURE_DOTPROD) + +inline static int32x4_t lm_ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { + const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); + const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + + return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); +} + +#else + +#define lm_ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) + +#endif // !defined(__ARM_FEATURE_DOTPROD) + +#endif // defined(__ARM_NEON) + +#if defined(__ARM_NEON) && !defined(_MSC_VER) + +#define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) +#define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) + +#define LM_GGML_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) + +static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { + lm_ggml_fp16_internal_t tmp; + memcpy(&tmp, &h, sizeof(lm_ggml_fp16_t)); + return (float)tmp; +} + +static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { + lm_ggml_fp16_t res; + lm_ggml_fp16_internal_t tmp = f; + memcpy(&res, &tmp, sizeof(lm_ggml_fp16_t)); + return res; +} + +#else + +#ifdef __wasm_simd128__ +#include +#else +#ifdef __POWER9_VECTOR__ +#include +#undef bool +#define bool _Bool +#else +#if defined(_MSC_VER) || defined(__MINGW32__) +#include +#else +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) +#if !defined(__riscv) +#include +#endif +#endif +#endif +#endif +#endif + +#ifdef __riscv_v_intrinsic +#include +#endif + +#if defined(__loongarch64) +#if defined(__loongarch_asx) +#include +#endif +#if defined(__loongarch_sx) +#include +#endif +#endif + +#if defined(__loongarch_asx) + +typedef union { + int32_t i; + float f; +} ft_union; + +/* float type data load instructions */ +static __m128 __lsx_vreplfr2vr_s(float val) { + ft_union fi_tmpval = {.f = val}; + return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i); +} + +static __m256 __lasx_xvreplfr2vr_s(float val) { + ft_union fi_tmpval = {.f = val}; + return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i); +} +#endif + +#ifdef __F16C__ + +#ifdef _MSC_VER +#define LM_GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) +#define LM_GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) +#else +#define LM_GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) +#define LM_GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) +#endif + +#elif defined(__POWER9_VECTOR__) + +#define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) +#define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) +/* the inline asm below is about 12% faster than the lookup method */ +#define LM_GGML_FP16_TO_FP32(x) LM_GGML_COMPUTE_FP16_TO_FP32(x) +#define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x) + +static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { + register float f; + register double d; + __asm__( + "mtfprd %0,%2\n" + "xscvhpdp %0,%0\n" + "frsp %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=f"(f): + /* in */ "r"(h)); + return f; +} + +static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { + register double d; + register lm_ggml_fp16_t r; + __asm__( /* xscvdphp can work on double or single precision */ + "xscvdphp %0,%2\n" + "mffprd %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=r"(r): + /* in */ "f"(f)); + return r; +} + +#else + +// FP16 <-> FP32 +// ref: https://github.com/Maratyszcza/FP16 + +static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32; + fp32.as_bits = w; + return fp32.as_value; +} + +static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32; + fp32.as_value = f; + return fp32.as_bits; +} + +static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { + const uint32_t w = (uint32_t) h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +#define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) +#define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) + +#endif // __F16C__ + +#endif // defined(__ARM_NEON) && (!defined(__MSC_VER) + +#ifdef __ARM_FEATURE_SVE +#include +#endif // __ARM_FEATURE_SVE + +// precomputed f32 table for f16 (256 KB) +// defined in ggml.c, initialized in lm_ggml_init() +extern float lm_ggml_table_f32_f16[1 << 16]; + +// On ARM NEON, it's quicker to directly convert x -> x instead of calling into lm_ggml_lookup_fp16_to_fp32, +// so we define LM_GGML_FP16_TO_FP32 and LM_GGML_FP32_TO_FP16 elsewhere for NEON. +// This is also true for POWER9. +#if !defined(LM_GGML_FP16_TO_FP32) +inline static float lm_ggml_lookup_fp16_to_fp32(lm_ggml_fp16_t f) { + uint16_t s; + memcpy(&s, &f, sizeof(uint16_t)); + return lm_ggml_table_f32_f16[s]; +} + +#define LM_GGML_FP16_TO_FP32(x) lm_ggml_lookup_fp16_to_fp32(x) +#endif + +#if !defined(LM_GGML_FP32_TO_FP16) +#define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x) +#endif + +#ifdef __cplusplus +} +#endif diff --git a/cpp/ggml-impl.h b/cpp/ggml-impl.h index 7a58aec..d25c868 100644 --- a/cpp/ggml-impl.h +++ b/cpp/ggml-impl.h @@ -1,15 +1,17 @@ #pragma once -#include "ggml.h" - // GGML internal header +#include "ggml.h" + #include #include // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/ -#include #include -#include // memcpy -#include // fabsf +#include + +#ifdef __cplusplus +extern "C" { +#endif #undef MIN #undef MAX @@ -17,95 +19,8 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) -#if defined(_MSC_VER) - -#define m512bh(p) p -#define m512i(p) p - -#else - -#define m512bh(p) (__m512bh)(p) -#define m512i(p) (__m512i)(p) - -#endif - -/** - * Converts brain16 to float32. - * - * The bfloat16 floating point format has the following structure: - * - * ┌sign - * │ - * │ ┌exponent - * │ │ - * │ │ ┌mantissa - * │ │ │ - * │┌──┴───┐┌─┴───┐ - * 0b0000000000000000 brain16 - * - * Since bf16 has the same number of exponent bits as a 32bit float, - * encoding and decoding numbers becomes relatively straightforward. - * - * ┌sign - * │ - * │ ┌exponent - * │ │ - * │ │ ┌mantissa - * │ │ │ - * │┌──┴───┐┌─┴───────────────────┐ - * 0b00000000000000000000000000000000 IEEE binary32 - * - * For comparison, the standard fp16 format has fewer exponent bits. - * - * ┌sign - * │ - * │ ┌exponent - * │ │ - * │ │ ┌mantissa - * │ │ │ - * │┌─┴─┐┌─┴──────┐ - * 0b0000000000000000 IEEE binary16 - * - * @see IEEE 754-2008 - */ -static inline float lm_ggml_compute_bf16_to_fp32(lm_ggml_bf16_t h) { - union { - float f; - uint32_t i; - } u; - u.i = (uint32_t)h.bits << 16; - return u.f; -} - -/** - * Converts float32 to brain16. - * - * This is binary identical with Google Brain float conversion. - * Floats shall round to nearest even, and NANs shall be quiet. - * Subnormals aren't flushed to zero, except perhaps when used. - * This code should vectorize nicely if using modern compilers. - */ -static inline lm_ggml_bf16_t lm_ggml_compute_fp32_to_bf16(float s) { - lm_ggml_bf16_t h; - union { - float f; - uint32_t i; - } u; - u.f = s; - if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */ - h.bits = (u.i >> 16) | 64; /* force to quiet */ - return h; - } - h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16; - return h; -} - -#define LM_GGML_FP32_TO_BF16(x) lm_ggml_compute_fp32_to_bf16(x) -#define LM_GGML_BF16_TO_FP32(x) lm_ggml_compute_bf16_to_fp32(x) - -#ifdef __cplusplus -extern "C" { -#endif +// required for mmap as gguf only guarantees 32-byte alignment +#define TENSOR_ALIGNMENT 32 // static_assert should be a #define, but if it's not, // fall back to the _Static_assert C11 keyword. @@ -121,519 +36,25 @@ extern "C" { #endif #endif -// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 -#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) -#ifndef __FMA__ -#define __FMA__ -#endif -#ifndef __F16C__ -#define __F16C__ -#endif -#endif - -// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available -#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)) -#ifndef __SSE3__ -#define __SSE3__ -#endif -#ifndef __SSSE3__ -#define __SSSE3__ -#endif -#endif - -#if defined(__ARM_FEATURE_SVE) -#include -#include -#endif - -// 16-bit float -// on Arm, we use __fp16 -// on x86, we use uint16_t -#if defined(__ARM_NEON) - -// if YCM cannot find , make a symbolic link to it, for example: // -// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// logging // -#include - -#ifdef _MSC_VER - -typedef uint16_t lm_ggml_fp16_internal_t; -#define lm_ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } +LM_GGML_ATTRIBUTE_FORMAT(2, 3) +void lm_ggml_log_internal (enum lm_ggml_log_level level, const char * format, ...); +void lm_ggml_log_callback_default(enum lm_ggml_log_level level, const char * text, void * user_data); -#else - -typedef __fp16 lm_ggml_fp16_internal_t; - -#define lm_ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } - -#endif // _MSC_VER - -#if !defined(__aarch64__) - -// 32-bit ARM compatibility - -// vaddvq_s16 -// vpaddq_s16 -// vpaddq_s32 -// vaddvq_s32 -// vaddvq_f32 -// vmaxvq_f32 -// vcvtnq_s32_f32 -// vzip1_u8 -// vzip2_u8 - -inline static int32_t vaddvq_s16(int16x8_t v) { - return - (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + - (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + - (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + - (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); -} - -inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { - int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); - int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); - return vcombine_s16(a0, b0); -} - -inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { - int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); - int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); - return vcombine_s32(a0, b0); -} - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} - -inline static float vaddvq_f32(float32x4_t v) { - return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); -} - -inline static float vmaxvq_f32(float32x4_t v) { - return - MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), - MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); -} - -inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { - int32x4_t res; - - res[0] = roundf(vgetq_lane_f32(v, 0)); - res[1] = roundf(vgetq_lane_f32(v, 1)); - res[2] = roundf(vgetq_lane_f32(v, 2)); - res[3] = roundf(vgetq_lane_f32(v, 3)); - - return res; -} - -inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { - uint8x8_t res; - - res[0] = a[0]; res[1] = b[0]; - res[2] = a[1]; res[3] = b[1]; - res[4] = a[2]; res[5] = b[2]; - res[6] = a[3]; res[7] = b[3]; - - return res; -} - -inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { - uint8x8_t res; - - res[0] = a[4]; res[1] = b[4]; - res[2] = a[5]; res[3] = b[5]; - res[4] = a[6]; res[5] = b[6]; - res[6] = a[7]; res[7] = b[7]; - - return res; -} - -// vld1q_s16_x2 -// vld1q_u8_x2 -// vld1q_u8_x4 -// vld1q_s8_x2 -// vld1q_s8_x4 -// TODO: double-check these work correctly - -typedef struct lm_ggml_int16x8x2_t { - int16x8_t val[2]; -} lm_ggml_int16x8x2_t; - -inline static lm_ggml_int16x8x2_t lm_ggml_vld1q_s16_x2(const int16_t * ptr) { - lm_ggml_int16x8x2_t res; - - res.val[0] = vld1q_s16(ptr + 0); - res.val[1] = vld1q_s16(ptr + 8); - - return res; -} - -typedef struct lm_ggml_uint8x16x2_t { - uint8x16_t val[2]; -} lm_ggml_uint8x16x2_t; - -inline static lm_ggml_uint8x16x2_t lm_ggml_vld1q_u8_x2(const uint8_t * ptr) { - lm_ggml_uint8x16x2_t res; - - res.val[0] = vld1q_u8(ptr + 0); - res.val[1] = vld1q_u8(ptr + 16); - - return res; -} - -typedef struct lm_ggml_uint8x16x4_t { - uint8x16_t val[4]; -} lm_ggml_uint8x16x4_t; - -inline static lm_ggml_uint8x16x4_t lm_ggml_vld1q_u8_x4(const uint8_t * ptr) { - lm_ggml_uint8x16x4_t res; - - res.val[0] = vld1q_u8(ptr + 0); - res.val[1] = vld1q_u8(ptr + 16); - res.val[2] = vld1q_u8(ptr + 32); - res.val[3] = vld1q_u8(ptr + 48); - - return res; -} - -typedef struct lm_ggml_int8x16x2_t { - int8x16_t val[2]; -} lm_ggml_int8x16x2_t; - -inline static lm_ggml_int8x16x2_t lm_ggml_vld1q_s8_x2(const int8_t * ptr) { - lm_ggml_int8x16x2_t res; - - res.val[0] = vld1q_s8(ptr + 0); - res.val[1] = vld1q_s8(ptr + 16); - - return res; -} - -typedef struct lm_ggml_int8x16x4_t { - int8x16_t val[4]; -} lm_ggml_int8x16x4_t; - -inline static lm_ggml_int8x16x4_t lm_ggml_vld1q_s8_x4(const int8_t * ptr) { - lm_ggml_int8x16x4_t res; - - res.val[0] = vld1q_s8(ptr + 0); - res.val[1] = vld1q_s8(ptr + 16); - res.val[2] = vld1q_s8(ptr + 32); - res.val[3] = vld1q_s8(ptr + 48); - - return res; -} - -// NOTE: not tested -inline static int8x16_t lm_ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { - int8x16_t res; - - res[ 0] = a[b[ 0]]; - res[ 1] = a[b[ 1]]; - res[ 2] = a[b[ 2]]; - res[ 3] = a[b[ 3]]; - res[ 4] = a[b[ 4]]; - res[ 5] = a[b[ 5]]; - res[ 6] = a[b[ 6]]; - res[ 7] = a[b[ 7]]; - res[ 8] = a[b[ 8]]; - res[ 9] = a[b[ 9]]; - res[10] = a[b[10]]; - res[11] = a[b[11]]; - res[12] = a[b[12]]; - res[13] = a[b[13]]; - res[14] = a[b[14]]; - res[15] = a[b[15]]; - - return res; -} - -// NOTE: not tested -inline static uint8x16_t lm_ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { - uint8x16_t res; - - res[ 0] = a[b[ 0]]; - res[ 1] = a[b[ 1]]; - res[ 2] = a[b[ 2]]; - res[ 3] = a[b[ 3]]; - res[ 4] = a[b[ 4]]; - res[ 5] = a[b[ 5]]; - res[ 6] = a[b[ 6]]; - res[ 7] = a[b[ 7]]; - res[ 8] = a[b[ 8]]; - res[ 9] = a[b[ 9]]; - res[10] = a[b[10]]; - res[11] = a[b[11]]; - res[12] = a[b[12]]; - res[13] = a[b[13]]; - res[14] = a[b[14]]; - res[15] = a[b[15]]; - - return res; -} - -#else - -#define lm_ggml_int16x8x2_t int16x8x2_t -#define lm_ggml_uint8x16x2_t uint8x16x2_t -#define lm_ggml_uint8x16x4_t uint8x16x4_t -#define lm_ggml_int8x16x2_t int8x16x2_t -#define lm_ggml_int8x16x4_t int8x16x4_t - -#define lm_ggml_vld1q_s16_x2 vld1q_s16_x2 -#define lm_ggml_vld1q_u8_x2 vld1q_u8_x2 -#define lm_ggml_vld1q_u8_x4 vld1q_u8_x4 -#define lm_ggml_vld1q_s8_x2 vld1q_s8_x2 -#define lm_ggml_vld1q_s8_x4 vld1q_s8_x4 -#define lm_ggml_vqtbl1q_s8 vqtbl1q_s8 -#define lm_ggml_vqtbl1q_u8 vqtbl1q_u8 - -#endif // !defined(__aarch64__) - -#if !defined(__ARM_FEATURE_DOTPROD) - -inline static int32x4_t lm_ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { - const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); - const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); - - return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); -} - -#else - -#define lm_ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) - -#endif // !defined(__ARM_FEATURE_DOTPROD) - -#endif // defined(__ARM_NEON) - -#if defined(__ARM_NEON) && !defined(_MSC_VER) - -#define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) -#define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) - -#define LM_GGML_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) - -static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { - lm_ggml_fp16_internal_t tmp; - memcpy(&tmp, &h, sizeof(lm_ggml_fp16_t)); - return (float)tmp; -} - -static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { - lm_ggml_fp16_t res; - lm_ggml_fp16_internal_t tmp = f; - memcpy(&res, &tmp, sizeof(lm_ggml_fp16_t)); - return res; -} - -#else - -#ifdef __wasm_simd128__ -#include -#else -#ifdef __POWER9_VECTOR__ -#include -#undef bool -#define bool _Bool -#else -#if defined(_MSC_VER) || defined(__MINGW32__) -#include -#else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) -#if !defined(__riscv) -#include -#endif -#endif -#endif -#endif -#endif - -#ifdef __riscv_v_intrinsic -#include -#endif - -#if defined(__loongarch64) -#if defined(__loongarch_asx) -#include -#endif -#if defined(__loongarch_sx) -#include -#endif -#endif - -#if defined(__loongarch_asx) - -typedef union { - int32_t i; - float f; -} ft_union; - -/* float type data load instructions */ -static __m128 __lsx_vreplfr2vr_s(float val) { - ft_union fi_tmpval = {.f = val}; - return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i); -} - -static __m256 __lasx_xvreplfr2vr_s(float val) { - ft_union fi_tmpval = {.f = val}; - return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i); -} -#endif - -#ifdef __F16C__ - -#ifdef _MSC_VER -#define LM_GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) -#define LM_GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) -#else -#define LM_GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) -#define LM_GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) -#endif - -#elif defined(__POWER9_VECTOR__) - -#define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) -#define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) -/* the inline asm below is about 12% faster than the lookup method */ -#define LM_GGML_FP16_TO_FP32(x) LM_GGML_COMPUTE_FP16_TO_FP32(x) -#define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x) - -static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { - register float f; - register double d; - __asm__( - "mtfprd %0,%2\n" - "xscvhpdp %0,%0\n" - "frsp %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=f"(f): - /* in */ "r"(h)); - return f; -} - -static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { - register double d; - register lm_ggml_fp16_t r; - __asm__( /* xscvdphp can work on double or single precision */ - "xscvdphp %0,%2\n" - "mffprd %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=r"(r): - /* in */ "f"(f)); - return r; -} - -#else - -// FP16 <-> FP32 -// ref: https://github.com/Maratyszcza/FP16 - -static inline float fp32_from_bits(uint32_t w) { - union { - uint32_t as_bits; - float as_value; - } fp32; - fp32.as_bits = w; - return fp32.as_value; -} - -static inline uint32_t fp32_to_bits(float f) { - union { - float as_value; - uint32_t as_bits; - } fp32; - fp32.as_value = f; - return fp32.as_bits; -} - -static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { - const uint32_t w = (uint32_t) h << 16; - const uint32_t sign = w & UINT32_C(0x80000000); - const uint32_t two_w = w + w; - - const uint32_t exp_offset = UINT32_C(0xE0) << 23; -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float exp_scale = 0x1.0p-112f; -#else - const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); -#endif - const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; - - const uint32_t magic_mask = UINT32_C(126) << 23; - const float magic_bias = 0.5f; - const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; - - const uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = sign | - (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); -} - -static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float scale_to_inf = 0x1.0p+112f; - const float scale_to_zero = 0x1.0p-110f; -#else - const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); - const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); -#endif - float base = (fabsf(f) * scale_to_inf) * scale_to_zero; - - const uint32_t w = fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } - - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); -} - -#define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) -#define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) - -#endif // __F16C__ - -#endif // defined(__ARM_NEON) && (!defined(__MSC_VER) - -#ifdef __ARM_FEATURE_SVE -#include -#endif // __ARM_FEATURE_SVE - -// precomputed f32 table for f16 (256 KB) -// defined in ggml.c, initialized in lm_ggml_init() -extern float lm_ggml_table_f32_f16[1 << 16]; - -// On ARM NEON, it's quicker to directly convert x -> x instead of calling into lm_ggml_lookup_fp16_to_fp32, -// so we define LM_GGML_FP16_TO_FP32 and LM_GGML_FP32_TO_FP16 elsewhere for NEON. -// This is also true for POWER9. -#if !defined(LM_GGML_FP16_TO_FP32) -inline static float lm_ggml_lookup_fp16_to_fp32(lm_ggml_fp16_t f) { - uint16_t s; - memcpy(&s, &f, sizeof(uint16_t)); - return lm_ggml_table_f32_f16[s]; -} - -#define LM_GGML_FP16_TO_FP32(x) lm_ggml_lookup_fp16_to_fp32(x) -#endif - -#if !defined(LM_GGML_FP32_TO_FP16) -#define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x) -#endif +#define LM_GGML_LOG(...) lm_ggml_log_internal(LM_GGML_LOG_LEVEL_NONE , __VA_ARGS__) +#define LM_GGML_LOG_INFO(...) lm_ggml_log_internal(LM_GGML_LOG_LEVEL_INFO , __VA_ARGS__) +#define LM_GGML_LOG_WARN(...) lm_ggml_log_internal(LM_GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define LM_GGML_LOG_ERROR(...) lm_ggml_log_internal(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define LM_GGML_LOG_DEBUG(...) lm_ggml_log_internal(LM_GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#define LM_GGML_LOG_CONT(...) lm_ggml_log_internal(LM_GGML_LOG_LEVEL_CONT , __VA_ARGS__) // bitset +typedef uint32_t lm_ggml_bitset_t; + static_assert(sizeof(lm_ggml_bitset_t) == 4, "bitset_t constants must be updated"); #define BITSET_SHR 5 // log2(sizeof(lm_ggml_bitset_t)*8) #define BITSET_MASK (sizeof(lm_ggml_bitset_t)*8 - 1) @@ -659,6 +80,12 @@ static inline void lm_ggml_bitset_clear(lm_ggml_bitset_t * bitset, size_t i) { #define LM_GGML_HASHSET_FULL ((size_t)-1) #define LM_GGML_HASHSET_ALREADY_EXISTS ((size_t)-2) +struct lm_ggml_hash_set { + size_t size; + lm_ggml_bitset_t * used; // whether or not the keys are in use i.e. set + struct lm_ggml_tensor ** keys; // actual tensors in the set, keys[i] is only defined if lm_ggml_bitset_get(used, i) +}; + struct lm_ggml_hash_set lm_ggml_hash_set_new(size_t size); void lm_ggml_hash_set_free(struct lm_ggml_hash_set * hash_set); @@ -748,6 +175,35 @@ static size_t lm_ggml_hash_find_or_insert(struct lm_ggml_hash_set * hash_set, st LM_GGML_ABORT("fatal error"); } +// computation graph + +enum lm_ggml_cgraph_eval_order { + LM_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0, + LM_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT, + LM_GGML_CGRAPH_EVAL_ORDER_COUNT +}; + +struct lm_ggml_cgraph { + int size; + int n_nodes; + int n_leafs; + + struct lm_ggml_tensor ** nodes; + struct lm_ggml_tensor ** grads; + struct lm_ggml_tensor ** leafs; + + struct lm_ggml_hash_set visited_hash_set; + + enum lm_ggml_cgraph_eval_order order; +}; + +struct lm_ggml_cgraph lm_ggml_graph_view(struct lm_ggml_cgraph * cgraph, int i0, int i1); + +// Memory allocation + +void * lm_ggml_aligned_malloc(size_t size); +void lm_ggml_aligned_free(void * ptr, size_t size); + #ifdef __cplusplus } #endif diff --git a/cpp/ggml-metal.h b/cpp/ggml-metal.h index 7bc21a9..28fc819 100644 --- a/cpp/ggml-metal.h +++ b/cpp/ggml-metal.h @@ -1,3 +1,5 @@ +// Note: this description is outdated +// // An interface allowing to compute lm_ggml_cgraph with Metal // // This is a fully functional interface that extends ggml with GPU support for Apple devices. @@ -25,9 +27,6 @@ #include #include -// max memory buffers that can be mapped to the device -#define LM_GGML_METAL_MAX_BUFFERS 64 - struct lm_ggml_tensor; struct lm_ggml_cgraph; @@ -40,19 +39,17 @@ extern "C" { // user-code should use only these functions // -LM_GGML_API void lm_ggml_backend_metal_log_set_callback(lm_ggml_log_callback log_callback, void * user_data); - LM_GGML_API lm_ggml_backend_t lm_ggml_backend_metal_init(void); LM_GGML_API bool lm_ggml_backend_is_metal(lm_ggml_backend_t backend); -LM_GGML_API LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size); - -LM_GGML_API void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb); +LM_GGML_DEPRECATED( + LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), + "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713"); LM_GGML_API void lm_ggml_backend_metal_set_abort_callback(lm_ggml_backend_t backend, lm_ggml_abort_callback abort_callback, void * user_data); -LM_GGML_API LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void); +LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void); // helper to check if the device supports a specific family // ideally, the user code should be doing these checks @@ -62,6 +59,8 @@ LM_GGML_API bool lm_ggml_backend_metal_supports_family(lm_ggml_backend_t backend // capture all command buffers committed the next time `lm_ggml_backend_graph_compute` is called LM_GGML_API void lm_ggml_backend_metal_capture_next_compute(lm_ggml_backend_t backend); +LM_GGML_API lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void); + #ifdef __cplusplus } #endif diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index a04db3f..d0a084e 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -1,7 +1,7 @@ #import "ggml-metal.h" +#import "ggml-impl.h" #import "ggml-backend-impl.h" -#import "ggml.h" #import @@ -12,18 +12,77 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) -#ifdef LM_GGML_METAL_NDEBUG -#define LM_GGML_METAL_LOG_INFO(...) -#define LM_GGML_METAL_LOG_WARN(...) -#define LM_GGML_METAL_LOG_ERROR(...) -#else -#define LM_GGML_METAL_LOG_INFO(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_INFO, __VA_ARGS__) -#define LM_GGML_METAL_LOG_WARN(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_WARN, __VA_ARGS__) -#define LM_GGML_METAL_LOG_ERROR(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__) -#endif +// max memory buffers that can be mapped to the device +#define LM_GGML_METAL_MAX_BUFFERS 64 + +// max number of MTLCommandBuffer used to submit a graph for processing +#define LM_GGML_METAL_MAX_COMMAND_BUFFERS 8 #define UNUSED(x) (void)(x) +// globals + +// overload of MTLGPUFamilyMetal3 (not available in some environments) +static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; + +// initialized in lm_ggml_backend_metal_reg +static struct lm_ggml_backend_reg g_lm_ggml_backend_metal_reg; +static struct lm_ggml_backend_device g_lm_ggml_backend_metal_device; + +// information about a Metal device +// note: assumes single GPU device - the default one +// TODO: support multiple GPU devices +static struct lm_ggml_backend_metal_device_context { + id mtl_device; + int mtl_device_ref_count; + + bool support_simdgroup_reduction; + bool support_simdgroup_mm; + + char name[128]; +} g_lm_ggml_ctx_dev_main = { + /*.mtl_device =*/ nil, + /*.mtl_device_ref_count =*/ 0, + /*.support_simdgroup_reduction =*/ false, + /*.support_simdgroup_mm =*/ false, + /*.name =*/ "", +}; + +// acquire +static id lm_ggml_backend_metal_device_acq(struct lm_ggml_backend_metal_device_context * ctx) { + assert(ctx != NULL); + + if (ctx->mtl_device == nil) { + ctx->mtl_device = MTLCreateSystemDefaultDevice(); + + ctx->support_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; + ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + + ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; + + strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1); + } + + ctx->mtl_device_ref_count++; + + return ctx->mtl_device; +} + +// release +static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device_context * ctx) { + assert(ctx != NULL); + assert(ctx->mtl_device_ref_count > 0); + + ctx->mtl_device_ref_count--; + + if (ctx->mtl_device_ref_count == 0) { + [ctx->mtl_device release]; + ctx->mtl_device = nil; + } +} + +// kernels + struct lm_ggml_metal_kernel { id pipeline; }; @@ -31,6 +90,8 @@ enum lm_ggml_metal_kernel_type { LM_GGML_METAL_KERNEL_TYPE_ADD, LM_GGML_METAL_KERNEL_TYPE_ADD_ROW, + LM_GGML_METAL_KERNEL_TYPE_SUB, + LM_GGML_METAL_KERNEL_TYPE_SUB_ROW, LM_GGML_METAL_KERNEL_TYPE_MUL, LM_GGML_METAL_KERNEL_TYPE_MUL_ROW, LM_GGML_METAL_KERNEL_TYPE_DIV, @@ -82,6 +143,8 @@ LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, LM_GGML_METAL_KERNEL_TYPE_NORM, + LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, + LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, @@ -178,6 +241,8 @@ LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, + LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, + LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, LM_GGML_METAL_KERNEL_TYPE_PAD_F32, LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, @@ -205,25 +270,42 @@ LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, LM_GGML_METAL_KERNEL_TYPE_CONCAT, LM_GGML_METAL_KERNEL_TYPE_SQR, + LM_GGML_METAL_KERNEL_TYPE_SQRT, + LM_GGML_METAL_KERNEL_TYPE_SIN, + LM_GGML_METAL_KERNEL_TYPE_COS, LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, + LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, + LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, LM_GGML_METAL_KERNEL_TYPE_COUNT }; struct lm_ggml_backend_metal_context { - int n_cb; - - id device; id queue; dispatch_queue_t d_queue; struct lm_ggml_metal_kernel kernels[LM_GGML_METAL_KERNEL_TYPE_COUNT]; - bool support_simdgroup_reduction; - bool support_simdgroup_mm; + // capture state + bool capture_next_compute; + bool capture_started; + + id capture_scope; + + // command buffer state + int n_cb; // number of extra threads used to submit the command buffers + int n_nodes_0; // number of nodes submitted by the main thread + int n_nodes_1; // remaining number of nodes submitted by the n_cb threads + int n_nodes_per_cb; - bool should_capture_next_compute; + struct lm_ggml_cgraph * gf; + + // the callback given to the thread pool + void (^encode_async)(size_t ith); + + // n_cb command buffers + 1 used by the main thread + id command_buffers[LM_GGML_METAL_MAX_COMMAND_BUFFERS + 1]; // abort lm_ggml_metal_graph_compute if callback returns true lm_ggml_abort_callback abort_callback; @@ -241,51 +323,19 @@ @interface LMGGMLMetalClass : NSObject @implementation LMGGMLMetalClass @end -static void lm_ggml_metal_default_log_callback(enum lm_ggml_log_level level, const char * msg, void * user_data) { - fprintf(stderr, "%s", msg); - - UNUSED(level); - UNUSED(user_data); -} - -lm_ggml_log_callback lm_ggml_metal_log_callback = lm_ggml_metal_default_log_callback; -void * lm_ggml_metal_log_user_data = NULL; - -LM_GGML_ATTRIBUTE_FORMAT(2, 3) -static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, ...){ - if (lm_ggml_metal_log_callback != NULL) { - va_list args; - va_start(args, format); - char buffer[128]; - int len = vsnprintf(buffer, 128, format, args); - if (len < 128) { - lm_ggml_metal_log_callback(level, buffer, lm_ggml_metal_log_user_data); - } else { - char* buffer2 = malloc(len+1); - va_end(args); - va_start(args, format); - vsnprintf(buffer2, len+1, format, args); - buffer2[len] = 0; - lm_ggml_metal_log_callback(level, buffer2, lm_ggml_metal_log_user_data); - free(buffer2); - } - va_end(args); - } -} - static void * lm_ggml_metal_host_malloc(size_t n) { void * data = NULL; #if TARGET_OS_OSX kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE); if (err != KERN_SUCCESS) { - LM_GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); + LM_GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); return NULL; } #else const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); if (result != 0) { - LM_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); + LM_GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); return NULL; } #endif @@ -293,27 +343,26 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, return data; } -static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(int n_cb) { - LM_GGML_METAL_LOG_INFO("%s: allocating\n", __func__); +static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend_dev_t dev) { + LM_GGML_LOG_INFO("%s: allocating\n", __func__); #if TARGET_OS_OSX && !LM_GGML_METAL_NDEBUG // Show all the Metal device instances in the system NSArray * devices = MTLCopyAllDevices(); for (id device in devices) { - LM_GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]); + LM_GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]); } [devices release]; // since it was created by a *Copy* C method #endif - // Pick and show default Metal device - id device = MTLCreateSystemDefaultDevice(); - LM_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); - - // Configure context + // init context struct lm_ggml_backend_metal_context * ctx = calloc(1, sizeof(struct lm_ggml_backend_metal_context)); - ctx->device = device; - ctx->n_cb = MIN(n_cb, LM_GGML_METAL_MAX_BUFFERS); - ctx->queue = [ctx->device newCommandQueue]; + struct lm_ggml_backend_metal_device_context * ctx_dev = dev->context; + + id device = lm_ggml_backend_metal_device_acq(ctx_dev); + LM_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); + + ctx->queue = [device newCommandQueue]; ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); id metal_library; @@ -344,28 +393,28 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, if (try_metallib && path_lib != nil) { // pre-compiled library found NSURL * libURL = [NSURL fileURLWithPath:path_lib]; - LM_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]); + LM_GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]); - metal_library = [ctx->device newLibraryWithURL:libURL error:&error]; + metal_library = [device newLibraryWithURL:libURL error:&error]; if (error) { - LM_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + LM_GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; } } else { #if LM_GGML_METAL_EMBED_LIBRARY - LM_GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__); + LM_GGML_LOG_INFO("%s: using embedded metal library\n", __func__); extern const char lm_ggml_metallib_start[]; extern const char lm_ggml_metallib_end[]; NSString * src = [[NSString alloc] initWithBytes:lm_ggml_metallib_start length:(lm_ggml_metallib_end-lm_ggml_metallib_start) encoding:NSUTF8StringEncoding]; #else - LM_GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); + LM_GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); NSString * path_source; NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"LM_GGML_METAL_PATH_RESOURCES"]; - LM_GGML_METAL_LOG_INFO("%s: LM_GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil"); + LM_GGML_LOG_INFO("%s: LM_GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil"); if (path_resource) { path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"]; @@ -374,15 +423,15 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, } if (path_source == nil) { - LM_GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); + LM_GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); path_source = @"ggml-metal.metal"; } - LM_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]); + LM_GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]); NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error]; if (error) { - LM_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + LM_GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; } #endif // LM_GGML_METAL_EMBED_LIBRARY @@ -396,9 +445,9 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, //[options setFastMathEnabled:false]; - metal_library = [ctx->device newLibraryWithSource:src options:options error:&error]; + metal_library = [device newLibraryWithSource:src options:options error:&error]; if (error) { - LM_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + LM_GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; } } @@ -406,56 +455,51 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, } // print MTL GPU family: - LM_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]); - - const NSInteger MTLGPUFamilyMetal3 = 5001; + LM_GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[device name] UTF8String]); // determine max supported GPU family // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf { for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { - if ([ctx->device supportsFamily:i]) { - LM_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); + if ([device supportsFamily:i]) { + LM_GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); break; } } for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { - if ([ctx->device supportsFamily:i]) { - LM_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); + if ([device supportsFamily:i]) { + LM_GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); break; } } - for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) { - if ([ctx->device supportsFamily:i]) { - LM_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i); + for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) { + if ([device supportsFamily:i]) { + LM_GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i); break; } } } - ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7]; - ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3]; + LM_GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false"); + LM_GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false"); + LM_GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); - ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7]; + ctx->capture_next_compute = false; + ctx->capture_started = false; + ctx->capture_scope = nil; - LM_GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false"); - LM_GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false"); - LM_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); - - ctx->should_capture_next_compute = false; + ctx->gf = nil; + ctx->encode_async = nil; + for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + ctx->command_buffers[i] = nil; + } #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - LM_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6); - } -#elif TARGET_OS_OSX - if (ctx->device.maxTransferRate != 0) { - LM_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6); - } else { - LM_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__); + LM_GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6); } #endif @@ -468,7 +512,7 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, } /* - LM_GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + LM_GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ (int) kernel->pipeline.threadExecutionWidth); \ */ @@ -476,21 +520,26 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, if (supported) { \ struct lm_ggml_metal_kernel * kernel = &ctx->kernels[e]; \ id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ - kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \ + kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \ [metal_function release]; \ if (error) { \ - LM_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + LM_GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ [metal_library release]; \ return NULL; \ } \ } else { \ - LM_GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ + LM_GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ } + const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm; + const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction; + // simd_sum and simd_max requires MTLGPUFamilyApple7 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD, add, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUB, sub, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL, mul, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV, div, true); @@ -511,10 +560,10 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, support_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); @@ -539,105 +588,109 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, support_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, support_simdgroup_mm); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); @@ -645,14 +698,14 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); @@ -665,22 +718,29 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } [metal_library release]; + return ctx; } static void lm_ggml_metal_free(struct lm_ggml_backend_metal_context * ctx) { - LM_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); + LM_GGML_LOG_INFO("%s: deallocating\n", __func__); for (int i = 0; i < LM_GGML_METAL_KERNEL_TYPE_COUNT; ++i) { [ctx->kernels[i].pipeline release]; } + Block_release(ctx->encode_async); + [ctx->queue release]; - [ctx->device release]; dispatch_release(ctx->d_queue); @@ -711,7 +771,7 @@ static void lm_ggml_metal_free(struct lm_ggml_backend_metal_context * ctx) { // Metal buffer based on the host memory pointer // static id lm_ggml_metal_get_buffer(struct lm_ggml_tensor * t, size_t * offs) { - //LM_GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); + //LM_GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); const int64_t tsize = lm_ggml_nbytes(t); @@ -723,28 +783,31 @@ static void lm_ggml_metal_free(struct lm_ggml_backend_metal_context * ctx) { for (int i = 0; i < buf_ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data; - //LM_GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size); + //LM_GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size); if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) { *offs = (size_t) ioffs; - //LM_GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); + //LM_GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); return buf_ctx->buffers[i].metal; } } - LM_GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); + LM_GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); return nil; } -static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_context * ctx, const struct lm_ggml_tensor * op) { +static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_context * ctx_dev, const struct lm_ggml_tensor * op) { for (size_t i = 0, n = 3; i < n; ++i) { if (op->src[i] != NULL && op->src[i]->type == LM_GGML_TYPE_BF16) { return false; } } + const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm; + const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction; + switch (op->op) { case LM_GGML_OP_UNARY: switch (lm_ggml_get_unary_op(op)) { @@ -765,26 +828,32 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_context case LM_GGML_OP_PERMUTE: case LM_GGML_OP_CONCAT: case LM_GGML_OP_ADD: + case LM_GGML_OP_SUB: case LM_GGML_OP_ACC: case LM_GGML_OP_MUL: case LM_GGML_OP_DIV: case LM_GGML_OP_REPEAT: case LM_GGML_OP_SCALE: case LM_GGML_OP_CLAMP: + return true; case LM_GGML_OP_SQR: + case LM_GGML_OP_SQRT: + case LM_GGML_OP_SIN: + case LM_GGML_OP_COS: + return lm_ggml_is_contiguous(op->src[0]); case LM_GGML_OP_SUM_ROWS: - return true; case LM_GGML_OP_SOFT_MAX: case LM_GGML_OP_RMS_NORM: case LM_GGML_OP_GROUP_NORM: - return ctx->support_simdgroup_reduction; + return support_simdgroup_reduction; case LM_GGML_OP_NORM: case LM_GGML_OP_ROPE: - case LM_GGML_OP_IM2COL: return true; + case LM_GGML_OP_IM2COL: + return op->src[0]->type == LM_GGML_TYPE_F16; case LM_GGML_OP_POOL_1D: - case LM_GGML_OP_POOL_2D: return false; + case LM_GGML_OP_POOL_2D: case LM_GGML_OP_UPSCALE: case LM_GGML_OP_PAD: case LM_GGML_OP_ARANGE: @@ -802,19 +871,13 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_context if (op->src[0]->ne[0] == 256) { return false; } - { - float logit_softcap; - - memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap)); - - if (logit_softcap != 0.0f) { - return false; - } - } - return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels + return support_simdgroup_mm; // TODO: over-restricted for vec-kernels + case LM_GGML_OP_SSM_CONV: + case LM_GGML_OP_SSM_SCAN: + return true; case LM_GGML_OP_MUL_MAT: case LM_GGML_OP_MUL_MAT_ID: - return ctx->support_simdgroup_reduction && + return support_simdgroup_reduction && (op->src[0]->type != LM_GGML_TYPE_F32 || op->src[1]->type == LM_GGML_TYPE_F32); case LM_GGML_OP_CPY: case LM_GGML_OP_DUP: @@ -857,715 +920,828 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_context } } -static enum lm_ggml_status lm_ggml_metal_graph_compute( - struct lm_ggml_backend_metal_context * ctx, - struct lm_ggml_cgraph * gf) { - - @autoreleasepool { - MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - edesc.dispatchType = MTLDispatchTypeSerial; +static void lm_ggml_metal_encode_node( + lm_ggml_backend_t backend, + int idx, + id encoder) { + struct lm_ggml_backend_metal_context * ctx = backend->context; + struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context; - // create multiple command buffers and enqueue them - // then, we encode the graph into the command buffers in parallel + struct lm_ggml_cgraph * gf = ctx->gf; - const int n_nodes = gf->n_nodes; - const int n_cb = ctx->n_cb; - const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; + struct lm_ggml_tensor * node = lm_ggml_graph_node(gf, idx); - const bool should_capture = ctx->should_capture_next_compute; - if (should_capture) { - ctx->should_capture_next_compute = false; + //LM_GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, lm_ggml_op_name(node->op)); - MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; - descriptor.captureObject = ctx->queue; + struct lm_ggml_tensor * src0 = node->src[0]; + struct lm_ggml_tensor * src1 = node->src[1]; + struct lm_ggml_tensor * src2 = node->src[2]; + struct lm_ggml_tensor * dst = node; - NSError * error = nil; - if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { - LM_GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); - LM_GGML_ABORT("capture failed"); - } + if (lm_ggml_is_empty(dst)) { + return; } - id command_buffer_builder[n_cb]; - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; - command_buffer_builder[cb_idx] = command_buffer; - - // always enqueue the first two command buffers - // enqueue all of the command buffers if we don't need to abort - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer enqueue]; - } + switch (dst->op) { + case LM_GGML_OP_NONE: + case LM_GGML_OP_RESHAPE: + case LM_GGML_OP_VIEW: + case LM_GGML_OP_TRANSPOSE: + case LM_GGML_OP_PERMUTE: + { + // noop -> next node + } return; + default: + { + } break; } - const id *command_buffers = command_buffer_builder; - - dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) { - const int cb_idx = iter; - - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_src2 = 0; - size_t offs_dst = 0; - - id command_buffer = command_buffers[cb_idx]; - id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; + if (!lm_ggml_metal_supports_op(ctx_dev, dst)) { + LM_GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, lm_ggml_op_desc(dst)); + LM_GGML_ABORT("unsupported op"); + } - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; + + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; + + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; + + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; + + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; LM_GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; LM_GGML_UNUSED(ne23); + + const uint64_t nb20 = src2 ? src2->nb[0] : 0; LM_GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; + + const enum lm_ggml_type src0t = src0 ? src0->type : LM_GGML_TYPE_COUNT; + const enum lm_ggml_type src1t = src1 ? src1->type : LM_GGML_TYPE_COUNT; + const enum lm_ggml_type dstt = dst ? dst->type : LM_GGML_TYPE_COUNT; + + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_src2 = 0; + size_t offs_dst = 0; + + id id_src0 = src0 ? lm_ggml_metal_get_buffer(src0, &offs_src0) : nil; + id id_src1 = src1 ? lm_ggml_metal_get_buffer(src1, &offs_src1) : nil; + id id_src2 = src2 ? lm_ggml_metal_get_buffer(src2, &offs_src2) : nil; + id id_dst = dst ? lm_ggml_metal_get_buffer(dst, &offs_dst) : nil; - for (int i = node_start; i < node_end; ++i) { - if (i == -1) { - [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; - continue; - } +#if 0 + LM_GGML_LOG_INFO("%s: op - %s\n", __func__, lm_ggml_op_name(dst->op)); + if (src0) { + LM_GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, + lm_ggml_is_contiguous(src0), src0->name); + } + if (src1) { + LM_GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + lm_ggml_is_contiguous(src1), src1->name); + } + if (dst) { + LM_GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, lm_ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, + dst->name); + } +#endif - //LM_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, lm_ggml_op_name(gf->nodes[i]->op)); + id device = ctx_dev->mtl_device; - struct lm_ggml_tensor * src0 = gf->nodes[i]->src[0]; - struct lm_ggml_tensor * src1 = gf->nodes[i]->src[1]; - struct lm_ggml_tensor * src2 = gf->nodes[i]->src[2]; - struct lm_ggml_tensor * dst = gf->nodes[i]; + switch (dst->op) { + case LM_GGML_OP_CONCAT: + { + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + + const int32_t dim = ((const int32_t *) dst->op_params)[0]; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_ADD: + case LM_GGML_OP_SUB: + case LM_GGML_OP_MUL: + case LM_GGML_OP_DIV: + { + LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); - if (lm_ggml_is_empty(dst)) { - continue; - } + const size_t offs = 0; - switch (dst->op) { - case LM_GGML_OP_NONE: - case LM_GGML_OP_RESHAPE: - case LM_GGML_OP_VIEW: - case LM_GGML_OP_TRANSPOSE: - case LM_GGML_OP_PERMUTE: - { - // noop -> next node - } continue; - default: - { - } break; - } + bool bcast_row = false; - if (!lm_ggml_metal_supports_op(ctx, dst)) { - LM_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, lm_ggml_op_desc(dst)); - LM_GGML_ABORT("unsupported op"); - } + int64_t nb = ne00; // used by the "row" kernels - if (should_capture) { - [encoder pushDebugGroup:[NSString stringWithCString:lm_ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; - } + id pipeline = nil; - const int64_t ne00 = src0 ? src0->ne[0] : 0; - const int64_t ne01 = src0 ? src0->ne[1] : 0; - const int64_t ne02 = src0 ? src0->ne[2] : 0; - const int64_t ne03 = src0 ? src0->ne[3] : 0; - - const uint64_t nb00 = src0 ? src0->nb[0] : 0; - const uint64_t nb01 = src0 ? src0->nb[1] : 0; - const uint64_t nb02 = src0 ? src0->nb[2] : 0; - const uint64_t nb03 = src0 ? src0->nb[3] : 0; - - const int64_t ne10 = src1 ? src1->ne[0] : 0; - const int64_t ne11 = src1 ? src1->ne[1] : 0; - const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; - - const uint64_t nb10 = src1 ? src1->nb[0] : 0; - const uint64_t nb11 = src1 ? src1->nb[1] : 0; - const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; - - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; LM_GGML_UNUSED(ne22); - const int64_t ne23 = src2 ? src2->ne[3] : 0; LM_GGML_UNUSED(ne23); - - const uint64_t nb20 = src2 ? src2->nb[0] : 0; LM_GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; - - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; - - const enum lm_ggml_type src0t = src0 ? src0->type : LM_GGML_TYPE_COUNT; - const enum lm_ggml_type src1t = src1 ? src1->type : LM_GGML_TYPE_COUNT; - const enum lm_ggml_type dstt = dst ? dst->type : LM_GGML_TYPE_COUNT; - - id id_src0 = src0 ? lm_ggml_metal_get_buffer(src0, &offs_src0) : nil; - id id_src1 = src1 ? lm_ggml_metal_get_buffer(src1, &offs_src1) : nil; - id id_src2 = src2 ? lm_ggml_metal_get_buffer(src2, &offs_src2) : nil; - id id_dst = dst ? lm_ggml_metal_get_buffer(dst, &offs_dst) : nil; - - //LM_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, lm_ggml_op_name(dst->op)); - //if (src0) { - // LM_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src0t), ne00, ne01, ne02, - // lm_ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // LM_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src1t), ne10, ne11, ne12, - // lm_ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // LM_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, lm_ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} - - switch (dst->op) { - case LM_GGML_OP_CONCAT: - { - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case LM_GGML_OP_ADD: - case LM_GGML_OP_MUL: - case LM_GGML_OP_DIV: - { - LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + if (lm_ggml_nelements(src1) == ne10 && lm_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - const size_t offs = 0; + // src1 is a row + LM_GGML_ASSERT(ne11 == 1); - bool bcast_row = false; + nb = ne00 / 4; + switch (dst->op) { + case LM_GGML_OP_ADD: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case LM_GGML_OP_SUB: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; + case LM_GGML_OP_MUL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; + case LM_GGML_OP_DIV: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; + default: LM_GGML_ABORT("fatal error"); + } - int64_t nb = ne00; // used by the "row" kernels + bcast_row = true; + } else { + switch (dst->op) { + case LM_GGML_OP_ADD: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case LM_GGML_OP_SUB: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; + case LM_GGML_OP_MUL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case LM_GGML_OP_DIV: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; + default: LM_GGML_ABORT("fatal error"); + } + } - id pipeline = nil; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + + if (bcast_row) { + const int64_t n = lm_ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + } break; + case LM_GGML_OP_REPEAT: + { + id pipeline; + + switch (src0t) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; + case LM_GGML_TYPE_I32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; + case LM_GGML_TYPE_I16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; + default: LM_GGML_ABORT("fatal error"); + } - if (lm_ggml_nelements(src1) == ne10 && lm_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_ACC: + { + LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(dstt == LM_GGML_TYPE_F32); + + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); + + const size_t pnb1 = ((const int32_t *) dst->op_params)[0]; + const size_t pnb2 = ((const int32_t *) dst->op_params)[1]; + const size_t pnb3 = ((const int32_t *) dst->op_params)[2]; + const size_t offs = ((const int32_t *) dst->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) dst->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + const id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } - // src1 is a row - LM_GGML_ASSERT(ne11 == 1); + const id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_SCALE: + { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - nb = ne00 / 4; - switch (dst->op) { - case LM_GGML_OP_ADD: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; - case LM_GGML_OP_MUL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; - case LM_GGML_OP_DIV: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; - default: LM_GGML_ABORT("fatal error"); - } + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); - bcast_row = true; - } else { - switch (dst->op) { - case LM_GGML_OP_ADD: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; - case LM_GGML_OP_MUL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case LM_GGML_OP_DIV: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: LM_GGML_ABORT("fatal error"); - } - } + int64_t n = lm_ggml_nelements(dst); - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; - - if (bcast_row) { - const int64_t n = lm_ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } else { - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + id pipeline = nil; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } break; - case LM_GGML_OP_REPEAT: - { - id pipeline; - - switch (src0t) { - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; - case LM_GGML_TYPE_I32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; - case LM_GGML_TYPE_I16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; - default: LM_GGML_ABORT("fatal error"); - } + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case LM_GGML_OP_ACC: - { - LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(dstt == LM_GGML_TYPE_F32); + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_OP_CLAMP: + { + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; - const size_t pnb1 = ((int32_t *) dst->op_params)[0]; - const size_t pnb2 = ((int32_t *) dst->op_params)[1]; - const size_t pnb3 = ((int32_t *) dst->op_params)[2]; - const size_t offs = ((int32_t *) dst->op_params)[3]; + float min; + float max; + memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; - if (!inplace) { - // run a separete kernel to cpy src->dst - // not sure how to avoid this - // TODO: make a simpler cpy_bytes kernel + const int64_t n = lm_ggml_nelements(dst); - const id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_OP_UNARY: + switch (lm_ggml_get_unary_op(node)) { + // we are not taking into account the strides, so for now require contiguous tensors + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } + case LM_GGML_UNARY_OP_TANH: + { + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_TANH].pipeline; - const id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case LM_GGML_OP_SCALE: - { - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); + const int64_t n = lm_ggml_nelements(dst); - int64_t n = lm_ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_UNARY_OP_RELU: + { + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RELU].pipeline; - id pipeline = nil; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } + const int64_t n = lm_ggml_nelements(dst); - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_UNARY_OP_SIGMOID: + { + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case LM_GGML_OP_CLAMP: - { - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - float min; - float max; - memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + const int64_t n = lm_ggml_nelements(dst); - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_UNARY_OP_GELU: + { + int64_t n = lm_ggml_nelements(dst); - const int64_t n = lm_ggml_nelements(dst); + id pipeline = nil; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case LM_GGML_OP_UNARY: - switch (lm_ggml_get_unary_op(gf->nodes[i])) { - // we are not taking into account the strides, so for now require contiguous tensors - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + if (n % 4 == 0) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } - case LM_GGML_UNARY_OP_TANH: - { - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_TANH].pipeline; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_UNARY_OP_GELU_QUICK: + { + int64_t n = lm_ggml_nelements(dst); - const int64_t n = lm_ggml_nelements(dst); + id pipeline = nil; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case LM_GGML_UNARY_OP_RELU: - { - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RELU].pipeline; + if (n % 4 == 0) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = lm_ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_UNARY_OP_SILU: + { + int64_t n = lm_ggml_nelements(dst); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case LM_GGML_UNARY_OP_SIGMOID: - { - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; + id pipeline = nil; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (n % 4 == 0) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } - const int64_t n = lm_ggml_nelements(dst); + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case LM_GGML_UNARY_OP_GELU: - { - int64_t n = lm_ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + default: + { + LM_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op)); + LM_GGML_ABORT("fatal error"); + } + } break; + case LM_GGML_OP_SQR: + { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - id pipeline = nil; + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SQR].pipeline; - if (n % 4 == 0) { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU].pipeline; - } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + const int64_t n = lm_ggml_nelements(dst); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case LM_GGML_UNARY_OP_GELU_QUICK: - { - int64_t n = lm_ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_OP_SQRT: + { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - id pipeline = nil; + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SQRT].pipeline; - if (n % 4 == 0) { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; - } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + const int64_t n = lm_ggml_nelements(dst); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case LM_GGML_UNARY_OP_SILU: - { - int64_t n = lm_ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_OP_SIN: + { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - id pipeline = nil; + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SIN].pipeline; - if (n % 4 == 0) { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SILU].pipeline; - } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + const int64_t n = lm_ggml_nelements(dst); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - default: - { - LM_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, lm_ggml_op_name(dst->op)); - LM_GGML_ABORT("fatal error"); - } - } break; - case LM_GGML_OP_SQR: - { - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_OP_COS: + { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SQR].pipeline; + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_COS].pipeline; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = lm_ggml_nelements(dst); + const int64_t n = lm_ggml_nelements(dst); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case LM_GGML_OP_SUM_ROWS: - { - LM_GGML_ASSERT(src0->nb[0] == lm_ggml_type_size(src0->type)); - - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case LM_GGML_OP_SOFT_MAX: - { - LM_GGML_ASSERT(!src1 || src1->type == LM_GGML_TYPE_F16 || src1->type == LM_GGML_TYPE_F32); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_OP_SUM_ROWS: + { + LM_GGML_ASSERT(src0->nb[0] == lm_ggml_type_size(src0->type)); + + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_OP_SOFT_MAX: + { + LM_GGML_ASSERT(!src1 || src1->type == LM_GGML_TYPE_F16 || src1->type == LM_GGML_TYPE_F32); - int nth = 32; // SIMD width + int nth = 32; // SIMD width - id pipeline = nil; + id pipeline = nil; - const bool use_f16 = (src1 && src1->type == LM_GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == LM_GGML_TYPE_F16); - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; - } - } + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } + } - float scale; - float max_bias; + float scale; + float max_bias; - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); - const int64_t nrows_x = lm_ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; + const int64_t nrows_x = lm_ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case LM_GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((int32_t *)(dst->op_params))[0]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_DIAG_MASK_INF: + { + const int n_past = ((const int32_t *)(dst->op_params))[0]; - id pipeline = nil; + id pipeline = nil; - if (ne00%8 == 0) { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; - } + if (ne00%8 == 0) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; + } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - if (ne00%8 == 0) { - [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - else { - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - } break; - case LM_GGML_OP_MUL_MAT: - { - LM_GGML_ASSERT(ne00 == ne10); + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + } break; + case LM_GGML_OP_SSM_CONV: + { + LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); + + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_OP_SSM_SCAN: + { + struct lm_ggml_tensor * src3 = node->src[3]; + struct lm_ggml_tensor * src4 = node->src[4]; + struct lm_ggml_tensor * src5 = node->src[5]; + + LM_GGML_ASSERT(src3); + LM_GGML_ASSERT(src4); + LM_GGML_ASSERT(src5); + + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; + + id id_src3 = src3 ? lm_ggml_metal_get_buffer(src3, &offs_src3) : nil; + id id_src4 = src4 ? lm_ggml_metal_get_buffer(src4, &offs_src4) : nil; + id id_src5 = src5 ? lm_ggml_metal_get_buffer(src5, &offs_src5) : nil; + + const int64_t ne30 = src3->ne[0]; LM_GGML_UNUSED(ne30); + const int64_t ne31 = src3->ne[1]; LM_GGML_UNUSED(ne31); + + const uint64_t nb30 = src3->nb[0]; + const uint64_t nb31 = src3->nb[1]; + + const int64_t ne40 = src4->ne[0]; LM_GGML_UNUSED(ne40); + const int64_t ne41 = src4->ne[1]; LM_GGML_UNUSED(ne41); + const int64_t ne42 = src4->ne[2]; LM_GGML_UNUSED(ne42); + + const uint64_t nb40 = src4->nb[0]; + const uint64_t nb41 = src4->nb[1]; + const uint64_t nb42 = src4->nb[2]; + + const int64_t ne50 = src5->ne[0]; LM_GGML_UNUSED(ne50); + const int64_t ne51 = src5->ne[1]; LM_GGML_UNUSED(ne51); + const int64_t ne52 = src5->ne[2]; LM_GGML_UNUSED(ne52); + + const uint64_t nb50 = src5->nb[0]; + const uint64_t nb51 = src5->nb[1]; + const uint64_t nb52 = src5->nb[2]; + + const int64_t d_state = ne00; + const int64_t d_inner = ne01; + const int64_t n_seq_tokens = ne11; + const int64_t n_seqs = ne02; + + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; + + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_OP_MUL_MAT: + { + LM_GGML_ASSERT(ne00 == ne10); - LM_GGML_ASSERT(ne12 % ne02 == 0); - LM_GGML_ASSERT(ne13 % ne03 == 0); + LM_GGML_ASSERT(ne12 % ne02 == 0); + LM_GGML_ASSERT(ne13 % ne03 == 0); - const uint r2 = ne12/ne02; - const uint r3 = ne13/ne03; + const uint r2 = ne12/ne02; + const uint r3 = ne13/ne03; - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - int ne11_mm_min = 1; + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = 1; #if 0 - // the numbers below are measured on M2 Ultra for 7B and 13B models - // these numbers do not translate to other devices or model sizes - // TODO: need to find a better approach - if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { + // the numbers below are measured on M2 Ultra for 7B and 13B models + // these numbers do not translate to other devices or model sizes + // TODO: need to find a better approach + if ([device.name isEqualToString:@"Apple M2 Ultra"]) { switch (src0t) { case LM_GGML_TYPE_F16: ne11_mm_min = 2; break; case LM_GGML_TYPE_Q8_0: ne11_mm_min = 7; break; @@ -1585,12 +1761,12 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - !lm_ggml_is_transposed(src0) && - !lm_ggml_is_transposed(src1) && - src1t == LM_GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && - (ne11 > ne11_mm_min || (lm_ggml_is_quantized(src0t) && ne12 > 1))) { + if ([device supportsFamily:MTLGPUFamilyApple7] && + !lm_ggml_is_transposed(src0) && + !lm_ggml_is_transposed(src1) && + src1t == LM_GGML_TYPE_F32 && + ne00 % 32 == 0 && ne00 >= 64 && + (ne11 > ne11_mm_min || (lm_ggml_is_quantized(src0t) && ne12 > 1))) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers @@ -1636,14 +1812,16 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:15]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:16]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -1797,7 +1975,7 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( } break; default: { - LM_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); + LM_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); LM_GGML_ABORT("not implemented"); } }; @@ -1812,16 +1990,18 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:19]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:20]; if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 || src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K || @@ -1859,1036 +2039,1207 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } - } break; - case LM_GGML_OP_MUL_MAT_ID: - { - const int n_as = src0->ne[2]; - - // src2 = ids - const enum lm_ggml_type src2t = src2->type; LM_GGML_UNUSED(src2t); - - LM_GGML_ASSERT(src2t == LM_GGML_TYPE_I32); - - LM_GGML_ASSERT(!lm_ggml_is_transposed(src0)); - LM_GGML_ASSERT(!lm_ggml_is_transposed(src1)); - - LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - // ne20 = n_used_experts - // ne21 = n_rows - const int dst_rows = ne20*ne21; - const int dst_rows_min = n_as; - const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4; - - // max size of the rowids array in the kernel shared buffer - LM_GGML_ASSERT(dst_rows <= dst_rows_max); - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - // !!! - // TODO: for now, always use mat-vec kernels until we figure out how to improve the - // indirect matrix multiplication - // !!! - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne00 % 32 == 0 && ne00 >= 64 && - dst_rows > dst_rows_min) { - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break; - case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - id pipeline = nil; - - switch (src0->type) { - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; - case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; - case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; - case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; - case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; - case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; - case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; - case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; - case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; - case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; - case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; - case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; - case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - default: LM_GGML_ABORT("MUL_MAT_ID not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; - - [encoder setThreadgroupMemoryLength:LM_GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + } break; + case LM_GGML_OP_MUL_MAT_ID: + { + const int n_as = src0->ne[2]; + + // src2 = ids + const enum lm_ggml_type src2t = src2->type; LM_GGML_UNUSED(src2t); + + LM_GGML_ASSERT(src2t == LM_GGML_TYPE_I32); + + LM_GGML_ASSERT(!lm_ggml_is_transposed(src0)); + LM_GGML_ASSERT(!lm_ggml_is_transposed(src1)); + + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + + LM_GGML_ASSERT(ne03 == 1); + LM_GGML_ASSERT(ne13 == 1); + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + // ne20 = n_used_experts + // ne21 = n_rows + const int dst_rows = ne20*ne21; + const int dst_rows_min = n_as; + const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4; + + // max size of the rowids array in the kernel shared buffer + LM_GGML_ASSERT(dst_rows <= dst_rows_max); + + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + // !!! + // TODO: for now, always use mat-vec kernels until we figure out how to improve the + // indirect matrix multiplication + // !!! + if ([device supportsFamily:MTLGPUFamilyApple7] && + ne00 % 32 == 0 && ne00 >= 64 && + dst_rows > dst_rows_min) { + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (src0->type) { + case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break; + case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break; + default: break; + } - id pipeline = nil; + id pipeline = nil; + + switch (src0->type) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; + case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; + case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; + case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; + case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; + case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; + case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; + case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; + case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; + case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; + case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; + case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; + case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; + default: LM_GGML_ABORT("MUL_MAT_ID not implemented"); + } - // use custom matrix x vector kernel - switch (src0t) { - case LM_GGML_TYPE_F32: - { - LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; - } break; - case LM_GGML_TYPE_F16: - { - LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; - } break; - case LM_GGML_TYPE_Q4_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; - } break; - case LM_GGML_TYPE_Q4_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; - } break; - case LM_GGML_TYPE_Q5_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; - } break; - case LM_GGML_TYPE_Q5_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; - } break; - case LM_GGML_TYPE_Q8_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; - } break; - case LM_GGML_TYPE_Q2_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; - } break; - case LM_GGML_TYPE_Q3_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; - } break; - case LM_GGML_TYPE_Q4_K: - { - nth0 = 4; //1; - nth1 = 8; //32; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; - } break; - case LM_GGML_TYPE_Q5_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; - } break; - case LM_GGML_TYPE_Q6_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ2_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ2_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ3_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ3_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ2_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ1_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ1_M: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ4_NL: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ4_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; - } break; - default: - { - LM_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); - LM_GGML_ABORT("not implemented"); - } - }; - - if (lm_ggml_is_quantized(src0t)) { - LM_GGML_ASSERT(ne00 >= nth0*nth1); + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + + [encoder setThreadgroupMemoryLength:LM_GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + id pipeline = nil; + + // use custom matrix x vector kernel + switch (src0t) { + case LM_GGML_TYPE_F32: + { + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; + } break; + case LM_GGML_TYPE_F16: + { + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; + } break; + case LM_GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; + } break; + case LM_GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; + } break; + case LM_GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; + } break; + case LM_GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; + } break; + case LM_GGML_TYPE_Q8_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; + } break; + case LM_GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; + } break; + case LM_GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; + } break; + case LM_GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; + } break; + case LM_GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; + } break; + case LM_GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ3_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ3_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ2_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ1_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ1_M: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ4_NL: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ4_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; + } break; + default: + { + LM_GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t); + LM_GGML_ABORT("not implemented"); } + }; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; - - const int64_t _ne1 = 1; - const int tgz = dst_rows; + if (lm_ggml_is_quantized(src0t)) { + LM_GGML_ASSERT(ne00 >= nth0*nth1); + } - if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 || - src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K || - src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) { - const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - } - } break; - case LM_GGML_OP_GET_ROWS: - { - id pipeline = nil; - - switch (src0->type) { - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; - case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; - case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; - case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; - case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; - case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; - case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; - case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; - case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; - case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; - case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; - case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; - case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; - case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; - case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; - case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; - case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; - case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; - case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; - case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; - case LM_GGML_TYPE_I32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; - default: LM_GGML_ABORT("not implemented"); - } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + + const int64_t _ne1 = 1; + const int tgz = dst_rows; + + if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 || + src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K || + src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) { + const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) { + const int mem_size = 32*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_Q3_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case LM_GGML_OP_GET_ROWS: + { + id pipeline = nil; + + switch (src0->type) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; + case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; + case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; + case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; + case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; + case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; + case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; + case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; + case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; + case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; + case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; + case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; + case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; + case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; + case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; + case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; + case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; + case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; + case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; + case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; + case LM_GGML_TYPE_I32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; - } break; - case LM_GGML_OP_RMS_NORM: - { - LM_GGML_ASSERT(ne00 % 4 == 0); - LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0)); + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + } break; + case LM_GGML_OP_RMS_NORM: + { + LM_GGML_ASSERT(ne00 % 4 == 0); + LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0)); - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); - int nth = 32; // SIMD width + int nth = 32; // SIMD width - while (nth < ne00/4 && nth < 1024) { - nth *= 2; - } + while (nth < ne00/4 && nth < 1024) { + nth *= 2; + } - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - const int64_t nrows = lm_ggml_nrows(src0); + const int64_t nrows = lm_ggml_nrows(src0); - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case LM_GGML_OP_GROUP_NORM: - { - LM_GGML_ASSERT(ne00 % 4 == 0); - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - - float eps; - memcpy(&eps, dst->op_params + 1, sizeof(float)); - - const int32_t n_groups = ((int32_t *) dst->op_params)[0]; - - int nth = 32; // SIMD width - - //while (nth < ne00/4 && nth < 1024) { - // nth *= 2; - //} - - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&eps length:sizeof( float) atIndex:9]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case LM_GGML_OP_NORM: - { - LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0)); + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_GROUP_NORM: + { + LM_GGML_ASSERT(ne00 % 4 == 0); + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); - const int nth = MIN(256, ne00); + const int32_t n_groups = ((const int32_t *) dst->op_params)[0]; - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NORM].pipeline; + int nth = 32; // SIMD width - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:LM_GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + //while (nth < ne00/4 && nth < 1024) { + // nth *= 2; + //} - const int64_t nrows = lm_ggml_nrows(src0); + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case LM_GGML_OP_ROPE: - { - LM_GGML_ASSERT(ne10 == ne02); + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - const int nth = MIN(1024, ne00); + [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_NORM: + { + LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0)); - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; + const int nth = MIN(256, ne00); - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NORM].pipeline; - const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:LM_GGML_PAD(nth*sizeof(float), 16) atIndex:0]; - id pipeline = nil; + const int64_t nrows = lm_ggml_nrows(src0); - if (!is_neox) { - switch (src0->type) { - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; - default: LM_GGML_ABORT("fatal error"); - }; - } else { - switch (src0->type) { - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; - default: LM_GGML_ABORT("fatal error"); - }; - } + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_ROPE: + { + LM_GGML_ASSERT(ne10 == ne02); + + const int nth = MIN(1024, ne00); + + const int n_past = ((const int32_t *) dst->op_params)[0]; + const int n_dims = ((const int32_t *) dst->op_params)[1]; + const int mode = ((const int32_t *) dst->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_ctx_orig = ((const int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); + + const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX; + + id pipeline = nil; + + if (!is_neox) { + switch (src0->type) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: LM_GGML_ABORT("fatal error"); + }; + } else { + switch (src0->type) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: LM_GGML_ABORT("fatal error"); + }; + } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case LM_GGML_OP_IM2COL: - { - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F16 || dst->type == LM_GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int32_t N = src1->ne[is_2D ? 3 : 2]; - const int32_t IC = src1->ne[is_2D ? 2 : 1]; - const int32_t IH = is_2D ? src1->ne[1] : 1; - const int32_t IW = src1->ne[0]; - - const int32_t KH = is_2D ? src0->ne[1] : 1; - const int32_t KW = src0->ne[0]; - - const int32_t OH = is_2D ? dst->ne[2] : 1; - const int32_t OW = dst->ne[1]; - - const int32_t CHW = IC * KH * KW; - - const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; - - id pipeline = nil; - - switch (dst->type) { - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; - default: LM_GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; - - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; - } break; - case LM_GGML_OP_UPSCALE: - { - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); - - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; - - const id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; - [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; - [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; - [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; + [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_IM2COL: + { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F16 || dst->type == LM_GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; + + const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup; + + switch (dst->type) { + case LM_GGML_TYPE_F32: { + pipeline = (is_gt_mttpt ? + ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline + : + ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline); } break; - case LM_GGML_OP_PAD: - { - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); - - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + case LM_GGML_TYPE_F16: { + pipeline = (is_gt_mttpt ? + ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline + : + ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline); } break; - case LM_GGML_OP_ARANGE: - { - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32); + default: LM_GGML_ABORT("fatal error"); + }; - float start; - float step; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12]; + + if (is_gt_mttpt) { + [encoder setBytes:&N length:sizeof(int32_t) atIndex:13]; + [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14]; + [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15]; + + const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N); + + const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + + [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } + } break; + case LM_GGML_OP_UPSCALE: + { + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; + + const id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; + [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; + [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; + [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_PAD: + { + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_ARANGE: + { + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32); - memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); + float start; + float step; - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float)); - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; - [encoder setBytes:&start length:sizeof(start) atIndex:2]; - [encoder setBytes:&step length:sizeof(step) atIndex:3]; + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; - const int nth = MIN(1024, ne0); + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; + [encoder setBytes:&start length:sizeof(start) atIndex:2]; + [encoder setBytes:&step length:sizeof(step) atIndex:3]; - [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case LM_GGML_OP_TIMESTEP_EMBEDDING: - { - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_TIMESTEP_EMBEDDING: + { + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); - const int dim = dst->op_params[0]; - const int max_period = dst->op_params[1]; + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; - const int half = dim / 2; + const int half = dim / 2; - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; - [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; + [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; - const int nth = MIN(1024, half); + const int nth = MIN(1024, half); - [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case LM_GGML_OP_ARGSORT: - { - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_I32); + [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_ARGSORT: + { + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_I32); - const int nrows = lm_ggml_nrows(src0); + const int nrows = lm_ggml_nrows(src0); - enum lm_ggml_sort_order order = (enum lm_ggml_sort_order) dst->op_params[0]; + enum lm_ggml_sort_order order = (enum lm_ggml_sort_order) dst->op_params[0]; - // bitonic sort requires the number of elements to be power of 2 - int64_t ne00_padded = 1; - while (ne00_padded < ne00) { - ne00_padded *= 2; - } + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } - // Metal kernels require the buffer size to be multiple of 16 bytes - // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength - const int mem_size = LM_GGML_PAD(ne00_padded*sizeof(int32_t), 16); + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const int mem_size = LM_GGML_PAD(ne00_padded*sizeof(int32_t), 16); - id pipeline = nil; + id pipeline = nil; - switch (order) { - case LM_GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; - case LM_GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; - default: LM_GGML_ABORT("fatal error"); - }; + switch (order) { + case LM_GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; + case LM_GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; + default: LM_GGML_ABORT("fatal error"); + }; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; - } break; - case LM_GGML_OP_LEAKY_RELU: - { - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; + } break; + case LM_GGML_OP_LEAKY_RELU: + { + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); - float slope; - memcpy(&slope, dst->op_params, sizeof(float)); + float slope; + memcpy(&slope, dst->op_params, sizeof(float)); - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; - const int64_t n = lm_ggml_nelements(dst); + const int64_t n = lm_ggml_nelements(dst); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case LM_GGML_OP_FLASH_ATTN_EXT: - { - LM_GGML_ASSERT(ne00 % 4 == 0); - LM_GGML_ASSERT(ne11 % 32 == 0); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_OP_FLASH_ATTN_EXT: + { + LM_GGML_ASSERT(ne00 % 4 == 0); + LM_GGML_ASSERT(ne11 % 32 == 0); - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(lm_ggml_are_same_shape (src1, src2)); + LM_GGML_ASSERT(lm_ggml_are_same_shape (src1, src2)); - struct lm_ggml_tensor * src3 = gf->nodes[i]->src[3]; + struct lm_ggml_tensor * src3 = node->src[3]; - size_t offs_src3 = 0; + size_t offs_src3 = 0; - id id_src3 = src3 ? lm_ggml_metal_get_buffer(src3, &offs_src3) : nil; + id id_src3 = src3 ? lm_ggml_metal_get_buffer(src3, &offs_src3) : nil; - LM_GGML_ASSERT(!src3 || src3->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(!src3 || src3->ne[1] >= LM_GGML_PAD(src0->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + LM_GGML_ASSERT(!src3 || src3->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(!src3 || src3->ne[1] >= LM_GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - const int64_t ne30 = src3 ? src3->ne[0] : 0; LM_GGML_UNUSED(ne30); - //const int64_t ne31 = src3 ? src3->ne[1] : 0; - const int64_t ne32 = src3 ? src3->ne[2] : 0; LM_GGML_UNUSED(ne32); - const int64_t ne33 = src3 ? src3->ne[3] : 0; LM_GGML_UNUSED(ne33); + const int64_t ne30 = src3 ? src3->ne[0] : 0; LM_GGML_UNUSED(ne30); + //const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; LM_GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; LM_GGML_UNUSED(ne33); - const uint64_t nb30 = src3 ? src3->nb[0] : 0; LM_GGML_UNUSED(nb30); - const uint64_t nb31 = src3 ? src3->nb[1] : 0; - const uint64_t nb32 = src3 ? src3->nb[2] : 0; LM_GGML_UNUSED(nb32); - const uint64_t nb33 = src3 ? src3->nb[3] : 0; LM_GGML_UNUSED(nb33); + const uint64_t nb30 = src3 ? src3->nb[0] : 0; LM_GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; LM_GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; LM_GGML_UNUSED(nb33); - const enum lm_ggml_type src2t = src2 ? src2->type : LM_GGML_TYPE_COUNT; LM_GGML_UNUSED(src2t); + const enum lm_ggml_type src2t = src2 ? src2->type : LM_GGML_TYPE_COUNT; LM_GGML_UNUSED(src2t); - float scale; - float max_bias; + float scale; + float max_bias; + float logit_softcap; + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } - const uint32_t n_head = src0->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - id pipeline = nil; + id pipeline = nil; - bool use_vec_kernel = false; + bool use_vec_kernel = false; - if (ne01 >= 4 || (ne00%128 != 0)) { - switch (ne00) { - case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - LM_GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - LM_GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - LM_GGML_ABORT("add template specialization for this size"); - } - } - } else { - use_vec_kernel = true; + if (ne01 >= 4 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + //case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ABORT("add template specialization for this size"); + } + } + } else { + use_vec_kernel = true; - switch (ne00) { - case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - default: - { - LM_GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - LM_GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - LM_GGML_ABORT("add template specialization for this size"); - } - } + switch (ne00) { + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + //case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ABORT("add template specialization for this size"); + } + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&scale length:sizeof( float) atIndex:23]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; + [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + LM_GGML_ASSERT(nqptg <= 32); + LM_GGML_ASSERT(nqptg % 8 == 0); + LM_GGML_ASSERT(ncpsg % 32 == 0); + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + if (smem > device.maxThreadgroupMemoryLength) { + break; } + nsgmax *= 2; + } + nsgmax /= 2; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength); + LM_GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + + [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + LM_GGML_ASSERT(nqptg <= 32); + LM_GGML_ASSERT(nqptg % 1 == 0); + LM_GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength); + LM_GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; + case LM_GGML_OP_DUP: + case LM_GGML_OP_CPY: + case LM_GGML_OP_CONT: + { + LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0); + + int nth = MIN(1024, ne00/lm_ggml_blck_size(src0->type)); + + id pipeline = nil; + + switch (src0t) { + case LM_GGML_TYPE_F32: + { + LM_GGML_ASSERT(ne0 % lm_ggml_blck_size(dst->type) == 0); + + switch (dstt) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; + case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; + case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; + case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; + case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; + case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; + case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + }; + } break; + case LM_GGML_TYPE_F16: + { + switch (dstt) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + }; + } break; + default: LM_GGML_ABORT("not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_POOL_2D: + { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32 && src0t == dstt); + + const int32_t * opts = dst->op_params; + enum lm_ggml_op_pool op = opts[0]; + + id pipeline = nil; + switch (src0t) { + case LM_GGML_TYPE_F32: { + switch(op) { + case LM_GGML_OP_POOL_AVG: + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break; + case LM_GGML_OP_POOL_MAX: + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break; + default: LM_GGML_ASSERT(false && "not implemented"); } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; - - if (!use_vec_kernel) { - // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - LM_GGML_ASSERT(nqptg <= 32); - LM_GGML_ASSERT(nqptg % 8 == 0); - LM_GGML_ASSERT(ncpsg % 32 == 0); - - int64_t nsgmax = 2; - - while (true) { - const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); - if (smem > ctx->device.maxThreadgroupMemoryLength) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; + } break; + default: LM_GGML_ASSERT(false && "not implemented"); + } + + const int32_t k0 = opts[1]; + const int32_t k1 = opts[2]; + const int32_t s0 = opts[3]; + const int32_t s1 = opts[4]; + const int32_t p0 = opts[5]; + const int32_t p1 = opts[6]; + + const int64_t IH = src0->ne[1]; + const int64_t IW = src0->ne[0]; + + const int64_t N = dst->ne[3]; + const int64_t OC = dst->ne[2]; + const int64_t OH = dst->ne[1]; + const int64_t OW = dst->ne[0]; + + const int64_t parallel_elements = N * OC * OH * OW; + const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); + const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2]; + [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3]; + [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4]; + [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5]; + [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6]; + [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7]; + [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8]; + [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9]; + [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10]; + [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11]; + [encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } break; + default: + { + LM_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op)); + LM_GGML_ABORT("fatal error"); + } + } +} + +static enum lm_ggml_status lm_ggml_metal_graph_compute( + lm_ggml_backend_t backend, + struct lm_ggml_cgraph * gf) { + struct lm_ggml_backend_metal_context * ctx = backend->context; + struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + // number of nodes encoded by the main thread (empirically determined) + const int n_main = 128; - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + // number of threads in addition to the main thread + const int n_cb = ctx->n_cb; - const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them + // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread + // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes + // each thread creates it's own command buffer and enqueues the ops in parallel + // + // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2 - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - LM_GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + @autoreleasepool { + ctx->gf = gf; - [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0]; + ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); + ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } else { - // half1x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; - LM_GGML_ASSERT(nqptg <= 32); - LM_GGML_ASSERT(nqptg % 1 == 0); - LM_GGML_ASSERT(ncpsg % 32 == 0); + const bool should_capture = ctx->capture_next_compute; + if (should_capture) { + ctx->capture_next_compute = false; - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + if (!ctx->capture_started) { + // create capture scope + ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device]; - int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; + MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->capture_scope; + descriptor.destination = MTLCaptureDestinationGPUTraceDocument; + descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + NSError * error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + LM_GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + } else { + [ctx->capture_scope beginScope]; + ctx->capture_started = true; + } + } + } - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - LM_GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0]; + // the main thread commits the first few commands immediately + // command_buffer[n_cb] + { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[n_cb] = command_buffer; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case LM_GGML_OP_DUP: - case LM_GGML_OP_CPY: - case LM_GGML_OP_CONT: - { - LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0); - - int nth = MIN(1024, ne00/lm_ggml_blck_size(src0->type)); - - id pipeline = nil; - - switch (src0t) { - case LM_GGML_TYPE_F32: - { - LM_GGML_ASSERT(ne0 % lm_ggml_blck_size(dst->type) == 0); - - switch (dstt) { - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; - case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; - case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; - case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; - case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; - case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; - default: LM_GGML_ABORT("not implemented"); - }; - } break; - case LM_GGML_TYPE_F16: - { - switch (dstt) { - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - default: LM_GGML_ABORT("not implemented"); - }; - } break; - default: LM_GGML_ABORT("not implemented"); - } + [command_buffer enqueue]; + ctx->encode_async(n_cb); + } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - { - LM_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, lm_ggml_op_name(dst->op)); - LM_GGML_ABORT("fatal error"); - } - } + // prepare the rest of the command buffers asynchronously + // command_buffer[0.. n_cb) + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[cb_idx] = command_buffer; - if (should_capture) { - [encoder popDebugGroup]; + // always enqueue the first two command buffers + // enqueue all of the command buffers if we don't need to abort + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer enqueue]; } } - [encoder endEncoding]; + dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer commit]; + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id command_buffer = ctx->command_buffers[n_cb]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + LM_GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return LM_GGML_STATUS_FAILED; + } } - }); - // Wait for completion and check status of each command buffer - // needed to detect if the device ran out-of-memory for example (#1881) + for (int i = 0; i < n_cb; ++i) { + id command_buffer = ctx->command_buffers[i]; + [command_buffer waitUntilCompleted]; - for (int i = 0; i < n_cb; ++i) { - id command_buffer = command_buffers[i]; - [command_buffer waitUntilCompleted]; + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + LM_GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } - MTLCommandBufferStatus status = [command_buffer status]; - if (status != MTLCommandBufferStatusCompleted) { - LM_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - if (status == MTLCommandBufferStatusError) { - NSString * error_code = [command_buffer error].localizedDescription; - LM_GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]); + return LM_GGML_STATUS_FAILED; } - return LM_GGML_STATUS_FAILED; - } + id next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil); + if (!next_buffer) { + continue; + } - id next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil); - if (!next_buffer) { - continue; - } + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } - bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); - if (next_queued) { - continue; - } + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + LM_GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); + return LM_GGML_STATUS_ABORTED; + } - if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { - LM_GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i); - return LM_GGML_STATUS_ABORTED; + [next_buffer commit]; } - [next_buffer commit]; - } - - if (should_capture) { - [[MTLCaptureManager sharedCaptureManager] stopCapture]; + if (!should_capture && ctx->capture_started) { + [ctx->capture_scope endScope]; + [[MTLCaptureManager sharedCaptureManager] stopCapture]; + } } - } return LM_GGML_STATUS_SUCCESS; } @@ -2896,44 +3247,13 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( // backend interface -// default buffer -static id g_backend_device = nil; -static int g_backend_device_ref_count = 0; - -static id lm_ggml_backend_metal_get_device(void) { - if (g_backend_device == nil) { - g_backend_device = MTLCreateSystemDefaultDevice(); - } - - g_backend_device_ref_count++; - - return g_backend_device; -} - -static void lm_ggml_backend_metal_free_device(void) { - assert(g_backend_device_ref_count > 0); - - g_backend_device_ref_count--; - - if (g_backend_device_ref_count == 0) { - [g_backend_device release]; - g_backend_device = nil; - } -} - -LM_GGML_CALL static const char * lm_ggml_backend_metal_buffer_get_name(lm_ggml_backend_buffer_t buffer) { - return "Metal"; - - UNUSED(buffer); -} - -LM_GGML_CALL static void lm_ggml_backend_metal_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { +static void lm_ggml_backend_metal_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { struct lm_ggml_backend_metal_buffer_context * ctx = (struct lm_ggml_backend_metal_buffer_context *)buffer->context; for (int i = 0; i < ctx->n_buffers; i++) { [ctx->buffers[i].metal release]; } - lm_ggml_backend_metal_free_device(); + lm_ggml_backend_metal_device_rel(buffer->buft->device->context); if (ctx->owned) { #if TARGET_OS_OSX @@ -2946,25 +3266,25 @@ LM_GGML_CALL static void lm_ggml_backend_metal_buffer_free_buffer(lm_ggml_backen free(ctx); } -LM_GGML_CALL static void * lm_ggml_backend_metal_buffer_get_base(lm_ggml_backend_buffer_t buffer) { +static void * lm_ggml_backend_metal_buffer_get_base(lm_ggml_backend_buffer_t buffer) { struct lm_ggml_backend_metal_buffer_context * ctx = (struct lm_ggml_backend_metal_buffer_context *)buffer->context; return ctx->all_data; } -LM_GGML_CALL static void lm_ggml_backend_metal_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +static void lm_ggml_backend_metal_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); UNUSED(buffer); } -LM_GGML_CALL static void lm_ggml_backend_metal_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) { +static void lm_ggml_backend_metal_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) { memcpy(data, (const char *)tensor->data + offset, size); UNUSED(buffer); } -LM_GGML_CALL static bool lm_ggml_backend_metal_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) { +static bool lm_ggml_backend_metal_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) { if (lm_ggml_backend_buffer_is_host(src->buffer)) { memcpy(dst->data, src->data, lm_ggml_nbytes(src)); return true; @@ -2974,17 +3294,17 @@ LM_GGML_CALL static bool lm_ggml_backend_metal_buffer_cpy_tensor(lm_ggml_backend UNUSED(buffer); } -LM_GGML_CALL static void lm_ggml_backend_metal_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { +static void lm_ggml_backend_metal_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { struct lm_ggml_backend_metal_buffer_context * ctx = (struct lm_ggml_backend_metal_buffer_context *)buffer->context; memset(ctx->all_data, value, ctx->all_size); } static struct lm_ggml_backend_buffer_i lm_ggml_backend_metal_buffer_i = { - /* .get_name = */ lm_ggml_backend_metal_buffer_get_name, /* .free_buffer = */ lm_ggml_backend_metal_buffer_free_buffer, /* .get_base = */ lm_ggml_backend_metal_buffer_get_base, /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, /* .set_tensor = */ lm_ggml_backend_metal_buffer_set_tensor, /* .get_tensor = */ lm_ggml_backend_metal_buffer_get_tensor, /* .cpy_tensor = */ lm_ggml_backend_metal_buffer_cpy_tensor, @@ -2994,7 +3314,7 @@ LM_GGML_CALL static void lm_ggml_backend_metal_buffer_clear(lm_ggml_backend_buff // default buffer type -LM_GGML_CALL static const char * lm_ggml_backend_metal_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { +static const char * lm_ggml_backend_metal_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { return "Metal"; UNUSED(buft); @@ -3004,19 +3324,17 @@ static void lm_ggml_backend_metal_log_allocated_size(id device, size_ #ifndef LM_GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + LM_GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n", __func__, size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { - LM_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); - } else { - LM_GGML_METAL_LOG_INFO("\n"); + LM_GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); } } else { - LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + LM_GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", __func__, size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0); @@ -3027,8 +3345,8 @@ static void lm_ggml_backend_metal_log_allocated_size(id device, size_ UNUSED(size_aligned); } -LM_GGML_CALL static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { - struct lm_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct lm_ggml_backend_metal_buffer_context)); +static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { + struct lm_ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct lm_ggml_backend_metal_buffer_context)); const size_t size_page = sysconf(_SC_PAGESIZE); @@ -3037,7 +3355,7 @@ LM_GGML_CALL static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_a size_aligned += (size_page - (size_aligned % size_page)); } - id device = lm_ggml_backend_metal_get_device(); + id device = lm_ggml_backend_metal_device_acq(buft->device->context); ctx->all_data = lm_ggml_metal_host_malloc(size_aligned); ctx->all_size = size_aligned; @@ -3045,18 +3363,22 @@ LM_GGML_CALL static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_a ctx->n_buffers = 1; if (ctx->all_data != NULL) { - ctx->buffers[0].data = ctx->all_data; - ctx->buffers[0].size = size; - ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data - length:size_aligned - options:MTLResourceStorageModeShared - deallocator:nil]; + ctx->buffers[0].data = ctx->all_data; + ctx->buffers[0].size = size; + ctx->buffers[0].metal = nil; + + if (size_aligned > 0) { + ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data + length:size_aligned + options:MTLResourceStorageModeShared + deallocator:nil]; + } } - if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) { - LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { + LM_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); free(ctx); - lm_ggml_backend_metal_free_device(); + lm_ggml_backend_metal_device_rel(buft->device->context); return NULL; } @@ -3065,28 +3387,28 @@ LM_GGML_CALL static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_a return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_metal_buffer_i, ctx, size); } -LM_GGML_CALL static size_t lm_ggml_backend_metal_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) { +static size_t lm_ggml_backend_metal_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) { return 32; UNUSED(buft); } -LM_GGML_CALL static size_t lm_ggml_backend_metal_buffer_type_get_max_size(lm_ggml_backend_buffer_type_t buft) { - id device = lm_ggml_backend_metal_get_device(); - size_t max_size = device.maxBufferLength; - lm_ggml_backend_metal_free_device(); +static size_t lm_ggml_backend_metal_buffer_type_get_max_size(lm_ggml_backend_buffer_type_t buft) { + id device = lm_ggml_backend_metal_device_acq(buft->device->context); + const size_t max_size = device.maxBufferLength; + lm_ggml_backend_metal_device_rel(buft->device->context); return max_size; UNUSED(buft); } -LM_GGML_CALL static bool lm_ggml_backend_metal_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) { +static bool lm_ggml_backend_metal_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) { return true; UNUSED(buft); } -LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void) { +lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void) { static struct lm_ggml_backend_buffer_type lm_ggml_backend_buffer_type_metal = { /* .iface = */ { /* .get_name = */ lm_ggml_backend_metal_buffer_type_get_name, @@ -3096,16 +3418,39 @@ LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(voi /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes /* .is_host = */ lm_ggml_backend_metal_buffer_type_is_host, }, + /* .device = */ &g_lm_ggml_backend_metal_device, /* .context = */ NULL, }; return &lm_ggml_backend_buffer_type_metal; } -// buffer from ptr +static const char * lm_ggml_backend_metal_buffer_from_ptr_type_get_name(lm_ggml_backend_buffer_type_t buft) { + return "Metal_Mapped"; + + UNUSED(buft); +} + +static lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_from_ptr_type(void) { + static struct lm_ggml_backend_buffer_type lm_ggml_backend_buffer_from_ptr_type_metal = { + /* .iface = */ { + /* .get_name = */ lm_ggml_backend_metal_buffer_from_ptr_type_get_name, + /* .alloc_buffer = */ lm_ggml_backend_metal_buffer_type_alloc_buffer, + /* .get_alignment = */ lm_ggml_backend_metal_buffer_type_get_alignment, + /* .get_max_size = */ lm_ggml_backend_metal_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes + /* .is_host = */ lm_ggml_backend_metal_buffer_type_is_host, + }, + /* .device = */ &g_lm_ggml_backend_metal_device, + /* .context = */ NULL, + }; + + return &lm_ggml_backend_buffer_from_ptr_type_metal; +} -LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { - struct lm_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct lm_ggml_backend_metal_buffer_context)); +// TODO: obsoleted by lm_ggml_backend_metal_device_buffer_from_ptr +lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { + struct lm_ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct lm_ggml_backend_metal_buffer_context)); ctx->all_data = data; ctx->all_size = size; @@ -3126,18 +3471,21 @@ LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void size_aligned += (size_page - (size_aligned % size_page)); } - id device = lm_ggml_backend_metal_get_device(); + id device = lm_ggml_backend_metal_device_acq(&g_lm_ggml_ctx_dev_main); // the buffer fits into the max buffer size allowed by the device if (size_aligned <= device.maxBufferLength) { - ctx->buffers[ctx->n_buffers].data = data; - ctx->buffers[ctx->n_buffers].size = size; + ctx->buffers[ctx->n_buffers].data = data; + ctx->buffers[ctx->n_buffers].size = size; + ctx->buffers[ctx->n_buffers].metal = nil; - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + if (size_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; - if (ctx->buffers[ctx->n_buffers].metal == nil) { - LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - return false; + if (ctx->buffers[ctx->n_buffers].metal == nil) { + LM_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + return false; + } } lm_ggml_backend_metal_log_allocated_size(device, size_aligned); @@ -3153,71 +3501,116 @@ LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void for (size_t i = 0; i < size; i += size_step) { const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); - ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); - ctx->buffers[ctx->n_buffers].size = size_step_aligned; + ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); + ctx->buffers[ctx->n_buffers].size = size_step_aligned; + ctx->buffers[ctx->n_buffers].metal = nil; - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + if (size_step_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; - if (ctx->buffers[ctx->n_buffers].metal == nil) { - LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); - return false; + if (ctx->buffers[ctx->n_buffers].metal == nil) { + LM_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); + return false; + } } lm_ggml_backend_metal_log_allocated_size(device, size_step_aligned); if (i + size_step < size) { - LM_GGML_METAL_LOG_INFO("\n"); + LM_GGML_LOG_INFO("\n"); } ++ctx->n_buffers; } } - return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_type(), lm_ggml_backend_metal_buffer_i, ctx, size); + return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_from_ptr_type(), lm_ggml_backend_metal_buffer_i, ctx, size); } // backend -LM_GGML_CALL static const char * lm_ggml_backend_metal_name(lm_ggml_backend_t backend) { +static const char * lm_ggml_backend_metal_name(lm_ggml_backend_t backend) { return "Metal"; UNUSED(backend); } -LM_GGML_CALL static void lm_ggml_backend_metal_free(lm_ggml_backend_t backend) { - struct lm_ggml_backend_metal_context * ctx = (struct lm_ggml_backend_metal_context *)backend->context; +static void lm_ggml_backend_metal_free(lm_ggml_backend_t backend) { + struct lm_ggml_backend_metal_context * ctx = backend->context; + struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + lm_ggml_backend_metal_device_rel(ctx_dev); lm_ggml_metal_free(ctx); + free(backend); } -LM_GGML_CALL static lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_get_default_buffer_type(lm_ggml_backend_t backend) { - return lm_ggml_backend_metal_buffer_type(); - - UNUSED(backend); +static enum lm_ggml_status lm_ggml_backend_metal_graph_compute(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph) { + return lm_ggml_metal_graph_compute(backend, cgraph); } -LM_GGML_CALL static enum lm_ggml_status lm_ggml_backend_metal_graph_compute(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph) { - struct lm_ggml_backend_metal_context * metal_ctx = (struct lm_ggml_backend_metal_context *)backend->context; +static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb) { + LM_GGML_ASSERT(lm_ggml_backend_is_metal(backend)); - return lm_ggml_metal_graph_compute(metal_ctx, cgraph); -} + struct lm_ggml_backend_metal_context * ctx = (struct lm_ggml_backend_metal_context *)backend->context; -LM_GGML_CALL static bool lm_ggml_backend_metal_supports_op(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op) { - struct lm_ggml_backend_metal_context * metal_ctx = (struct lm_ggml_backend_metal_context *)backend->context; + if (ctx->n_cb != n_cb) { + ctx->n_cb = MIN(n_cb, LM_GGML_METAL_MAX_COMMAND_BUFFERS); - return lm_ggml_metal_supports_op(metal_ctx, op); -} + if (ctx->n_cb > 2) { + LM_GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); + } + } -LM_GGML_CALL static bool lm_ggml_backend_metal_supports_buft(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name; + if (ctx->encode_async) { + Block_release(ctx->encode_async); + } - UNUSED(backend); + ctx->encode_async = Block_copy(^(size_t iter) { + const int cb_idx = iter; + const int n_cb_l = ctx->n_cb; + + const int n_nodes_0 = ctx->n_nodes_0; + const int n_nodes_1 = ctx->n_nodes_1; + + const int n_nodes_per_cb = ctx->n_nodes_per_cb; + + id command_buffer = ctx->command_buffers[cb_idx]; + id encoder = [command_buffer computeCommandEncoder]; + + int node_start = 0; + int node_end = n_nodes_0; + + if (cb_idx < n_cb_l) { + node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); + node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); + } + + const bool should_capture = ctx->capture_next_compute; + + for (int idx = node_start; idx < node_end; ++idx) { + if (should_capture) { + [encoder pushDebugGroup:[NSString stringWithCString:lm_ggml_op_desc(lm_ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; + } + + lm_ggml_metal_encode_node(backend, idx, encoder); + + if (should_capture) { + [encoder popDebugGroup]; + } + } + + [encoder endEncoding]; + + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer commit]; + } + }); } static struct lm_ggml_backend_i lm_ggml_backend_metal_i = { /* .get_name = */ lm_ggml_backend_metal_name, /* .free = */ lm_ggml_backend_metal_free, - /* .get_default_buffer_type = */ lm_ggml_backend_metal_get_default_buffer_type, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, @@ -3227,56 +3620,43 @@ LM_GGML_CALL static bool lm_ggml_backend_metal_supports_buft(lm_ggml_backend_t b /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ lm_ggml_backend_metal_graph_compute, - /* .supports_op = */ lm_ggml_backend_metal_supports_op, - /* .supports_buft = */ lm_ggml_backend_metal_supports_buft, - /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, /* .event_record = */ NULL, /* .event_wait = */ NULL, - /* .event_synchronize = */ NULL, }; -void lm_ggml_backend_metal_log_set_callback(lm_ggml_log_callback log_callback, void * user_data) { - lm_ggml_metal_log_callback = log_callback; - lm_ggml_metal_log_user_data = user_data; -} - static lm_ggml_guid_t lm_ggml_backend_metal_guid(void) { static lm_ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 }; return &guid; } +// TODO: remove in the future lm_ggml_backend_t lm_ggml_backend_metal_init(void) { - struct lm_ggml_backend_metal_context * ctx = lm_ggml_metal_init(LM_GGML_DEFAULT_N_THREADS); + lm_ggml_backend_dev_t dev = lm_ggml_backend_reg_dev_get(lm_ggml_backend_metal_reg(), 0); + + struct lm_ggml_backend_metal_context * ctx = lm_ggml_metal_init(dev); if (ctx == NULL) { - LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + LM_GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); return NULL; } - lm_ggml_backend_t metal_backend = malloc(sizeof(struct lm_ggml_backend)); + lm_ggml_backend_t backend = malloc(sizeof(struct lm_ggml_backend)); - *metal_backend = (struct lm_ggml_backend) { + *backend = (struct lm_ggml_backend) { /* .guid = */ lm_ggml_backend_metal_guid(), /* .interface = */ lm_ggml_backend_metal_i, + /* .device = */ dev, /* .context = */ ctx, }; - return metal_backend; + lm_ggml_backend_metal_set_n_cb(backend, 1); + + return backend; } bool lm_ggml_backend_is_metal(lm_ggml_backend_t backend) { return backend != NULL && lm_ggml_guid_matches(backend->guid, lm_ggml_backend_metal_guid()); } -void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb) { - LM_GGML_ASSERT(lm_ggml_backend_is_metal(backend)); - - struct lm_ggml_backend_metal_context * ctx = (struct lm_ggml_backend_metal_context *)backend->context; - - ctx->n_cb = MIN(n_cb, LM_GGML_METAL_MAX_BUFFERS); -} - void lm_ggml_backend_metal_set_abort_callback(lm_ggml_backend_t backend, lm_ggml_abort_callback abort_callback, void * user_data) { LM_GGML_ASSERT(lm_ggml_backend_is_metal(backend)); @@ -3289,23 +3669,258 @@ void lm_ggml_backend_metal_set_abort_callback(lm_ggml_backend_t backend, lm_ggml bool lm_ggml_backend_metal_supports_family(lm_ggml_backend_t backend, int family) { LM_GGML_ASSERT(lm_ggml_backend_is_metal(backend)); - struct lm_ggml_backend_metal_context * ctx = (struct lm_ggml_backend_metal_context *)backend->context; + struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context; - return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; + return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; } void lm_ggml_backend_metal_capture_next_compute(lm_ggml_backend_t backend) { LM_GGML_ASSERT(lm_ggml_backend_is_metal(backend)); struct lm_ggml_backend_metal_context * ctx = (struct lm_ggml_backend_metal_context *)backend->context; - ctx->should_capture_next_compute = true; + ctx->capture_next_compute = true; +} + +// backend device + +static const char * lm_ggml_backend_metal_device_get_name(lm_ggml_backend_dev_t dev) { + return "Metal"; + + LM_GGML_UNUSED(dev); +} + +static const char * lm_ggml_backend_metal_device_get_description(lm_ggml_backend_dev_t dev) { + // acq/rel just to populate ctx->name in case it hasn't been done yet + struct lm_ggml_backend_metal_device_context * ctx_dev = (struct lm_ggml_backend_metal_device_context *)dev->context; + lm_ggml_backend_metal_device_acq(ctx_dev); + lm_ggml_backend_metal_device_rel(ctx_dev); + + return ctx_dev->name; +} + +static void lm_ggml_backend_metal_device_get_memory(lm_ggml_backend_dev_t dev, size_t * free, size_t * total) { + if (@available(macOS 10.12, iOS 16.0, *)) { + struct lm_ggml_backend_metal_device_context * ctx_dev = (struct lm_ggml_backend_metal_device_context *)dev->context; + id device = lm_ggml_backend_metal_device_acq(ctx_dev); + + *total = device.recommendedMaxWorkingSetSize; + *free = *total - device.currentAllocatedSize; + + lm_ggml_backend_metal_device_rel(ctx_dev); + } else { + *free = 1; + *total = 1; + } +} + +static enum lm_ggml_backend_dev_type lm_ggml_backend_metal_device_get_type(lm_ggml_backend_dev_t dev) { + return LM_GGML_BACKEND_DEVICE_TYPE_GPU; + + LM_GGML_UNUSED(dev); +} + +static void lm_ggml_backend_metal_device_get_props(lm_ggml_backend_dev_t dev, struct lm_ggml_backend_dev_props * props) { + props->name = lm_ggml_backend_metal_device_get_name(dev); + props->description = lm_ggml_backend_metal_device_get_description(dev); + props->type = lm_ggml_backend_metal_device_get_type(dev); + lm_ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = (struct lm_ggml_backend_dev_caps) { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; } -LM_GGML_CALL lm_ggml_backend_t lm_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning +static lm_ggml_backend_t lm_ggml_backend_metal_device_init(lm_ggml_backend_dev_t dev, const char * params) { + struct lm_ggml_backend_metal_context * ctx = lm_ggml_metal_init(dev); + if (ctx == NULL) { + LM_GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + return NULL; + } + + lm_ggml_backend_t backend = malloc(sizeof(struct lm_ggml_backend)); + + *backend = (struct lm_ggml_backend) { + /* .guid = */ lm_ggml_backend_metal_guid(), + /* .interface = */ lm_ggml_backend_metal_i, + /* .device = */ dev, + /* .context = */ ctx, + }; + + lm_ggml_backend_metal_set_n_cb(backend, 1); -LM_GGML_CALL lm_ggml_backend_t lm_ggml_backend_reg_metal_init(const char * params, void * user_data) { - return lm_ggml_backend_metal_init(); + return backend; LM_GGML_UNUSED(params); - LM_GGML_UNUSED(user_data); +} + +static lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_device_get_buffer_type(lm_ggml_backend_dev_t dev) { + return lm_ggml_backend_metal_buffer_type(); + + LM_GGML_UNUSED(dev); +} + +static lm_ggml_backend_buffer_t lm_ggml_backend_metal_device_buffer_from_ptr(lm_ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + struct lm_ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct lm_ggml_backend_metal_buffer_context)); + + ctx->all_data = ptr; + ctx->all_size = size; + ctx->owned = false; + ctx->n_buffers = 0; + + const size_t size_page = sysconf(_SC_PAGESIZE); + + // page-align the data ptr + { + const uintptr_t offs = (uintptr_t) ptr % size_page; + ptr = (void *) ((char *) ptr - offs); + size += offs; + } + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + struct lm_ggml_backend_metal_device_context * ctx_dev = (struct lm_ggml_backend_metal_device_context *)dev->context; + id device = lm_ggml_backend_metal_device_acq(ctx_dev); + + // the buffer fits into the max buffer size allowed by the device + if (size_aligned <= device.maxBufferLength) { + ctx->buffers[ctx->n_buffers].data = ptr; + ctx->buffers[ctx->n_buffers].size = size; + ctx->buffers[ctx->n_buffers].metal = nil; + + if (size_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + LM_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + return false; + } + } + + lm_ggml_backend_metal_log_allocated_size(device, size_aligned); + + ++ctx->n_buffers; + } else { + // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into + // one of the views + const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case + const size_t size_step = device.maxBufferLength - size_ovlp; + const size_t size_view = device.maxBufferLength; + + for (size_t i = 0; i < size; i += size_step) { + const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); + + ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) ptr + i); + ctx->buffers[ctx->n_buffers].size = size_step_aligned; + ctx->buffers[ctx->n_buffers].metal = nil; + + if (size_step_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + LM_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); + return false; + } + } + + lm_ggml_backend_metal_log_allocated_size(device, size_step_aligned); + + if (i + size_step < size) { + LM_GGML_LOG_INFO("\n"); + } + + ++ctx->n_buffers; + } + } + + return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_type(), lm_ggml_backend_metal_buffer_i, ctx, size); +} + +static bool lm_ggml_backend_metal_device_supports_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) { + struct lm_ggml_backend_metal_device_context * ctx_dev = dev->context; + + return lm_ggml_metal_supports_op(ctx_dev, op); +} + +static bool lm_ggml_backend_metal_device_supports_buft(lm_ggml_backend_dev_t dev, lm_ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name; + + UNUSED(dev); +} + +static bool lm_ggml_backend_metal_device_offload_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) { + return false; + + LM_GGML_UNUSED(dev); + LM_GGML_UNUSED(op); +} + +static struct lm_ggml_backend_device_i lm_ggml_backend_metal_device_i = { + /* .get_name = */ lm_ggml_backend_metal_device_get_name, + /* .get_description = */ lm_ggml_backend_metal_device_get_description, + /* .get_memory = */ lm_ggml_backend_metal_device_get_memory, + /* .get_type = */ lm_ggml_backend_metal_device_get_type, + /* .get_props = */ lm_ggml_backend_metal_device_get_props, + /* .init_backend = */ lm_ggml_backend_metal_device_init, + /* .get_buffer_type = */ lm_ggml_backend_metal_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ lm_ggml_backend_metal_device_buffer_from_ptr, + /* .supports_op = */ lm_ggml_backend_metal_device_supports_op, + /* .supports_buft = */ lm_ggml_backend_metal_device_supports_buft, + /* .offload_op = */ lm_ggml_backend_metal_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend registry + +static const char * lm_ggml_backend_metal_reg_get_name(lm_ggml_backend_reg_t reg) { + return "Metal"; + + LM_GGML_UNUSED(reg); +} + +static size_t lm_ggml_backend_metal_reg_device_count(lm_ggml_backend_reg_t reg) { + return 1; + + LM_GGML_UNUSED(reg); +} + +static lm_ggml_backend_dev_t lm_ggml_backend_metal_reg_device_get(lm_ggml_backend_reg_t reg, size_t index) { + LM_GGML_ASSERT(index == 0); + + return &g_lm_ggml_backend_metal_device; + + LM_GGML_UNUSED(reg); + LM_GGML_UNUSED(index); +} + +static struct lm_ggml_backend_reg_i lm_ggml_backend_metal_reg_i = { + /* .get_name = */ lm_ggml_backend_metal_reg_get_name, + /* .device_count = */ lm_ggml_backend_metal_reg_device_count, + /* .device_get = */ lm_ggml_backend_metal_reg_device_get, + /* .get_proc_address = */ NULL, +}; + +lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void) { + // TODO: make this thread-safe somehow? + { + g_lm_ggml_backend_metal_reg = (struct lm_ggml_backend_reg) { + /* .iface = */ lm_ggml_backend_metal_reg_i, + /* .context = */ NULL, + }; + + g_lm_ggml_backend_metal_device = (struct lm_ggml_backend_device) { + /* .iface = */ lm_ggml_backend_metal_device_i, + /* .reg = */ &g_lm_ggml_backend_metal_reg, + /* .context = */ &g_lm_ggml_ctx_dev_main, + }; + } + + return &g_lm_ggml_backend_metal_reg; } diff --git a/cpp/ggml-quants.c b/cpp/ggml-quants.c index 96eec35..6c7a582 100644 --- a/cpp/ggml-quants.c +++ b/cpp/ggml-quants.c @@ -3,6 +3,7 @@ #include "ggml-quants.h" #include "ggml-impl.h" +#include "ggml-cpu-impl.h" #include @@ -230,6 +231,12 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) return _mm_packus_epi16( bytes1, bytes2); } + +static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { + const __m128i ax = _mm_sign_epi8(x, x); + const __m128i sy = _mm_sign_epi8(y, x); + return _mm_maddubs_epi16(ax, sy); +} #endif #elif defined(__SSSE3__) // horizontally add 4x4 floats @@ -1630,7 +1637,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6 // ===================== Helper functions // static inline int nearest_int(float fval) { - assert(fval <= 4194303.f); + assert(fabsf(fval) <= 4194303.f); float val = fval + 12582912.f; int i; memcpy(&i, &val, sizeof(int)); return (i & 0x007fffff) - 0x00400000; @@ -3306,6 +3313,191 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } +// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) + +void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK_K; j++) { + const float v = x[j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = LM_GGML_FP32_TO_FP16(d); + + // 5 elements per byte, along 32 bytes + for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = 0; + for (size_t n = 0; n < 5; ++n) { + int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2 + q *= 3; + q += xi; + } + // ceiling division (243 == pow(3, 5)) + q = ((uint16_t)q * 256 + (243 - 1)) / 243; + y[i].qs[j + m] = q; + } + x += 5*32; + } + // along 16 bytes + for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = 0; + for (size_t n = 0; n < 5; ++n) { + int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2 + q *= 3; + q += xi; + } + // ceiling division (243 == pow(3, 5)) + q = ((uint16_t)q * 256 + (243 - 1)) / 243; + y[i].qs[j + m] = q; + } + x += 5*16; + } + // 4 elements per byte + for (size_t j = 0; j < sizeof(y->qh); ++j) { + uint8_t q = 0; + for (size_t m = 0; m < 4; ++m) { + // -1, 0, 1 -> 0, 1, 2 + int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1; + q *= 3; + q += xi; + } + // shift the first value to the most significant trit + q *= 3; + // ceiling division (243 == pow(3, 5)) + q = ((uint16_t)q * 256 + (243 - 1)) / 243; + y[i].qh[j] = q; + } + x += 4*sizeof(y->qh); + } +} + +void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK_K; j++) { + const float v = x[j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = LM_GGML_FP32_TO_FP16(d); + + for (size_t j = 0; j < sizeof(y->qs); j += 32) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = 0; + for (size_t n = 0; n < 4; ++n) { + // -1, 0, 1 -> 0, 1, 2 + int xi = lroundf(x[m + n*32] * id) + 1; + q += (xi & 3) << (2*n); + } + y[i].qs[j + m] = q; + } + x += 4*32; + } + } +} + +void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_tq1_0 * restrict y = vy; + quantize_row_tq1_0_ref(x, y, k); +} + +void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_tq2_0 * restrict y = vy; + quantize_row_tq2_0_ref(x, y, k); +} + +size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_TQ1_0, n_per_row); + quantize_row_tq1_0(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + +size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_TQ2_0, n_per_row); + quantize_row_tq2_0(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + + +void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + for (int64_t i = 0; i < nb; ++i) { + + const float d = LM_GGML_FP16_TO_FP32(x[i].d); + + for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { + for (size_t n = 0; n < 5; ++n) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[n]; + int16_t xi = ((uint16_t) q * 3) >> 8; + *y++ = (float) (xi - 1) * d; + } + } + } + for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { + for (size_t n = 0; n < 5; ++n) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[n]; + int16_t xi = ((uint16_t) q * 3) >> 8; + *y++ = (float) (xi - 1) * d; + } + } + } + + for (size_t n = 0; n < 4; ++n) { + for (size_t j = 0; j < sizeof(x->qh); ++j) { + uint8_t q = x[i].qh[j] * pow3[n]; + int16_t xi = ((uint16_t) q * 3) >> 8; + *y++ = (float) (xi - 1) * d; + } + } + } +} + +void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; ++i) { + + const float d = LM_GGML_FP16_TO_FP32(x[i].d); + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + for (size_t l = 0; l < 4; ++l) { + for (size_t m = 0; m < 32; ++m) { + int8_t q = (x[i].qs[j + m] >> (l*2)) & 3; + *y++ = (float) (q - 1) * d; + } + } + } + } +} + // ====================== "True" 2-bit (de)-quantization void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { @@ -3644,7 +3836,7 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q8_K_ref(x, y, k); } -//===================================== Dot ptoducts ================================= +//===================================== Dot products ================================= // // Helper functions @@ -3818,42 +4010,141 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void float sumf = 0; #if defined(__ARM_FEATURE_SVE) - if (lm_ggml_sve_cnt_b == QK8_0) { - const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); - const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); - - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * restrict x0 = &x[ib + 0]; - const block_q4_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); + const int vector_length = lm_ggml_cpu_get_sve_cnt()*8; - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + // VLA Implementation using switch case + switch (vector_length) { + case 128: + { + // predicate for activating higher lanes for 4 float32 elements + const svbool_t ph4 = svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F)); + const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04)); + const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F)); + const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04)); + + // sub 8 + const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8); + const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8); + const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8); + const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8); + + // load y + const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16); + const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs); + const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16); + + // dot product + sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4, + svdot_s32(svdup_n_s32(0), qx0ls, qy0l), + svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4, + svdot_s32(svdup_n_s32(0), qx1ls, qy1l), + svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 256: + { + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } - // dot product - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 512: + { + // predicate for activating higher lanes for 32 int8 elements + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); + + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes + const svbool_t pl16 = svnot_b_z(ph32, ph16); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs); + const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(ph32, y0->qs); + const svint8_t qy1 = svld1_s8(ph32, y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32, + svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32, + svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1)); + } break; + default: + assert(false && "Unsupported vector length"); + break; } + #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -3922,37 +4213,37 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void sumf = hsum_float_8(acc); #elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); - - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); - by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0); + const __m128i mone = _mm_set1_epi16(1); - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); - // Apply the scale, and accumulate - acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8)); + const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8)); + const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8)); + const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8)); + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); + accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)), + _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1); + accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)), + _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2); } - sumf = hsum_float_8(acc); + sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); #elif defined(__SSSE3__) // set constants const __m128i lowMask = _mm_set1_epi8(0xF); @@ -5303,29 +5594,124 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void float sumf = 0; #if defined(__ARM_FEATURE_SVE) - if (lm_ggml_sve_cnt_b == QK8_0) { - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; + const int vector_length = lm_ggml_cpu_get_sve_cnt()*8; - // load x - const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); - const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + //VLA Implemenation for SVE + switch (vector_length) { + case 128: + { + // predicate for activating lanes for 16 Int8 elements + const svbool_t ph16 = svptrue_pat_b8 (SV_VL16); + const svbool_t pl16 = svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svint8_t qx0_0 = svld1_s8(ph16, x0->qs); + const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16); + const svint8_t qx1_0 = svld1_s8(ph16, x1->qs); + const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16); + + // load y + const svint8_t qy0_0 = svld1_s8(ph16, y0->qs); + const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16); + const svint8_t qy1_0 = svld1_s8(ph16, y1->qs); + const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16); + + sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16, + svdot_s32(svdup_n_s32(0), qx0_0, qy0_0), + svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16, + svdot_s32(svdup_n_s32(0), qx1_0, qy1_0), + svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1)); + } break; + case 256: + { + //printf("sve256"); + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 512: + { + // predicate for activating high 256 bit + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); + // predicate for activating low 256 bit + const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32); - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + // predicate for activating high lanes for 8 float32 elements + const svbool_t ph8 = svptrue_pat_b32(SV_VL8); + // predicate for activating low lanes for 8 float32 elements + const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8); + + svfloat32_t sumv00 = svdup_n_f32(0.0f); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits + // and add them to make one 64 element vector + // load x + const svint8_t qx_32 = svld1_s8(ph32, x0->qs); + svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2); + + qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); + + // load y + const svint8_t qy_32 = svld1_s8(ph32, y0->qs); + svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2); + + qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); + + // scale creation + const float32_t deq1 = LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d); + const float32_t deq2 = LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d); + + // duplicate deq1 in first half of vector and deq2 in second half of vector + const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2); + + const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); + + sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp); + } + + sumf = svaddv_f32(svptrue_b32(), sumv00); + break; + } + default: + assert(false && "Unsupported vector length"); + break; } #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); @@ -5470,6 +5856,501 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void *s = sumf; } +void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + float sumf = 0.0f; + + uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + const uint8x16_t shift = vld1q_u8(k_shift); + + for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) + int32x4_t sumi0 = vdupq_n_s32(0); + int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif + + // first 32 bytes of 5 elements + { + uint8x16_t qx0 = vld1q_u8(x[i].qs + 0); + uint8x16_t qx1 = vld1q_u8(x[i].qs + 16); + uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3)); + uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3)); + uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9)); + uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9)); + uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27)); + uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27)); + uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81)); + uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81)); + + // multiply by 3 and keep the 2 bits above 8 bits + int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); + int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6)); + int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6)); + int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6)); + int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + 0); + const int8x16_t qy1 = vld1q_s8(y[i].qs + 16); + const int8x16_t qy2 = vld1q_s8(y[i].qs + 32); + const int8x16_t qy3 = vld1q_s8(y[i].qs + 48); + const int8x16_t qy4 = vld1q_s8(y[i].qs + 64); + const int8x16_t qy5 = vld1q_s8(y[i].qs + 80); + const int8x16_t qy6 = vld1q_s8(y[i].qs + 96); + const int8x16_t qy7 = vld1q_s8(y[i].qs + 112); + const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); + const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); + sumi0 = vdotq_s32(sumi0, sqx6, qy6); + sumi1 = vdotq_s32(sumi1, sqx7, qy7); + sumi0 = vdotq_s32(sumi0, sqx8, qy8); + sumi1 = vdotq_s32(sumi1, sqx9, qy9); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9)); +#endif + } + + // last 16 bytes of 5-element, along with the 4 bytes of 4 elements + { + uint8x16_t qx0 = vld1q_u8(x[i].qs + 32); + uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3)); + uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9)); + uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27)); + uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned + uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh)); + qx5 = vmulq_u8(qx5, shift); + + // multiply by 3 and keep the 2 bits above 8 bits + int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + 160); + const int8x16_t qy1 = vld1q_s8(y[i].qs + 176); + const int8x16_t qy2 = vld1q_s8(y[i].qs + 192); + const int8x16_t qy3 = vld1q_s8(y[i].qs + 208); + const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); + const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); +#endif + } + + const int16x8_t ysum0 = vld1q_s16(y[i].bsums); + const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else + sumi0 = vaddq_s16(sumi0, sumi1); + sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); + + sumf += d * (float) vaddlvq_s16(sumi0); +#endif + } + + *s = sumf; + +#elif defined(__AVX2__) + __m256 sumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + // 16-bit sums + __m256i sumi0 = _mm256_setzero_si256(); + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + + // first 32 bytes of 5 elements + { + __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs)); + // 8-bit multiplies with shifts, masks and adds + __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3 + __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9 + __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9 + __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9 + + // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits? + + // Cancel the +1 from avg so that it behaves like a halving add + qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1)); + qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1)); + qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1)); + qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1)); + qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1)); + // Multiply by 3 and get the top 2 bits + qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256())); + qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256())); + qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256())); + qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256())); + qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256())); + qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3)); + qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3)); + + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0)); + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32)); + const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64)); + const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96)); + const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128)); + + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); + qx4 = _mm256_maddubs_epi16(qx4, qy4); + + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); + sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); + sumi2 = _mm256_add_epi16(sumi2, qx4); + } + + // last 16 bytes of 5-element, along with the 4 bytes of 4 elements + { + __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned + __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh)); + __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3 + __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9 + __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9 + __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9 + __m256i qx01 = MM256_SET_M128I(qx1, qx0); + __m256i qx23 = MM256_SET_M128I(qx3, qx2); + + // avx2 does not have 8-bit multiplies, so 16-bit it is. + qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1)); + qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF)); + __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1)); + + __m256i qx45 = MM256_SET_M128I(qx5, qx4); + + // Cancel the +1 from avg so that it behaves like a halving add + qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1)); + qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1)); + qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1)); + // Multiply by 3 and get the top 2 bits + qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256())); + qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256())); + qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256())); + qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3)); + qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3)); + qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3)); + + const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160)); + const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192)); + const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224)); + + qx01 = _mm256_maddubs_epi16(qx01, qy01); + qx23 = _mm256_maddubs_epi16(qx23, qy23); + qx45 = _mm256_maddubs_epi16(qx45, qy45); + + sumi0 = _mm256_add_epi16(sumi0, qx01); + sumi1 = _mm256_add_epi16(sumi1, qx23); + sumi2 = _mm256_add_epi16(sumi2, qx45); + } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); + const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(x[i].d)); + + sumi0 = _mm256_sub_epi16(sumi0, ysum); + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2)); + sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); + + sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); + } + + *s = hsum_float_8(sumf); + +#else + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int sum = 0; + + for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*32 + m]; + } + } + } + for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*16 + m]; + } + } + } + + for (size_t l = 0; l < 4; ++l) { + for (size_t j = 0; j < sizeof(x->qh); ++j) { + uint8_t q = x[i].qh[j] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j]; + } + } + + sumf += (float) sum * (LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d); + } + + *s = sumf; +#endif +} + +void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + float sumf = 0.0f; + + const uint8x16_t m3 = vdupq_n_u8(3); + + for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) + int32x4_t sumi0 = vdupq_n_s32(0); + int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + uint8x16_t qx0 = vld1q_u8(x[i].qs + j); + uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16); + uint8x16_t qx2 = vshrq_n_u8(qx0, 2); + uint8x16_t qx3 = vshrq_n_u8(qx1, 2); + uint8x16_t qx4 = vshrq_n_u8(qx0, 4); + uint8x16_t qx5 = vshrq_n_u8(qx1, 4); + uint8x16_t qx6 = vshrq_n_u8(qx0, 6); + uint8x16_t qx7 = vshrq_n_u8(qx1, 6); + + int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3)); + int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3)); + int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0); + const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16); + const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32); + const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48); + const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64); + const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80); + const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); + const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); + sumi0 = vdotq_s32(sumi0, sqx6, qy6); + sumi1 = vdotq_s32(sumi1, sqx7, qy7); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); +#endif + } + + const int16x8_t ysum0 = vld1q_s16(y[i].bsums); + const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else + sumi0 = vaddq_s16(sumi0, sumi1); + sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); + + sumf += d * (float) vaddlvq_s16(sumi0); +#endif + } + + *s = sumf; + +#elif defined(__AVX2__) + __m256 sumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + // 16-bit sums, because 256*127 still fits + __m256i sumi0 = _mm256_setzero_si256(); + __m256i sumi1 = _mm256_setzero_si256(); + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j)); + __m256i qx1 = _mm256_srli_epi16(qx0, 2); + __m256i qx2 = _mm256_srli_epi16(qx0, 4); + __m256i qx3 = _mm256_srli_epi16(qx0, 6); + + // 0, 1, 2 (should not be 3) + qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3)); + + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0)); + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32)); + const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64)); + const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96)); + + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); + + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); + sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); + } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); + const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(x[i].d)); + + sumi0 = _mm256_add_epi16(sumi0, sumi1); + sumi0 = _mm256_sub_epi16(sumi0, ysum); + sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); + + sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); + } + + *s = hsum_float_8(sumf); + +#else + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + for (size_t l = 0; l < 4; ++l) { + for (size_t k = 0; k < 32; ++k) { + sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1); + } + } + } + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + + sumf += (float) sumi * d; + } + + *s = sumf; +#endif +} + void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -10945,15 +11826,6 @@ void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const voi #endif } - -#if defined(__AVX__) -static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { - const __m128i ax = _mm_sign_epi8(x, x); - const __m128i sy = _mm_sign_epi8(y, x); - return _mm_maddubs_epi16(ax, sy); -} -#endif - #if defined(__AVX2__) static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { const __m256i ax = _mm256_sign_epi8(x, x); @@ -14800,6 +15672,14 @@ bool lm_ggml_validate_row_data(enum lm_ggml_type type, const void * data, size_t } } } break; + case LM_GGML_TYPE_TQ1_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb); + } break; + case LM_GGML_TYPE_TQ2_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb); + } break; case LM_GGML_TYPE_IQ1_S: { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb); diff --git a/cpp/ggml-quants.h b/cpp/ggml-quants.h index 2ac316a..5505824 100644 --- a/cpp/ggml-quants.h +++ b/cpp/ggml-quants.h @@ -26,6 +26,9 @@ void quantize_row_q5_K_ref(const float * LM_GGML_RESTRICT x, block_q5_K * LM_GGM void quantize_row_q6_K_ref(const float * LM_GGML_RESTRICT x, block_q6_K * LM_GGML_RESTRICT y, int64_t k); void quantize_row_q8_K_ref(const float * LM_GGML_RESTRICT x, block_q8_K * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_tq1_0_ref(const float * LM_GGML_RESTRICT x, block_tq1_0 * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_tq2_0_ref(const float * LM_GGML_RESTRICT x, block_tq2_0 * LM_GGML_RESTRICT y, int64_t k); + void quantize_row_iq3_xxs_ref(const float * LM_GGML_RESTRICT x, block_iq3_xxs * LM_GGML_RESTRICT y, int64_t k); void quantize_row_iq4_nl_ref (const float * LM_GGML_RESTRICT x, block_iq4_nl * LM_GGML_RESTRICT y, int64_t k); void quantize_row_iq4_xs_ref (const float * LM_GGML_RESTRICT x, block_iq4_xs * LM_GGML_RESTRICT y, int64_t k); @@ -46,6 +49,9 @@ void quantize_row_q5_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT void quantize_row_q6_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); void quantize_row_q8_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_tq1_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_tq2_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); + void quantize_row_iq3_xxs(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); void quantize_row_iq4_nl (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); void quantize_row_iq4_xs (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); @@ -67,6 +73,9 @@ void dequantize_row_q5_K(const block_q5_K * LM_GGML_RESTRICT x, float * LM_GGML_ void dequantize_row_q6_K(const block_q6_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); void dequantize_row_q8_K(const block_q8_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +void dequantize_row_tq1_0(const block_tq1_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +void dequantize_row_tq2_0(const block_tq2_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); + void dequantize_row_iq2_xxs(const block_iq2_xxs * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); void dequantize_row_iq2_xs (const block_iq2_xs * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); void dequantize_row_iq2_s (const block_iq2_s * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); @@ -90,6 +99,9 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); + void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); void lm_ggml_vec_dot_iq2_xs_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); void lm_ggml_vec_dot_iq2_s_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); @@ -111,6 +123,9 @@ size_t quantize_iq4_nl (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTR size_t quantize_iq4_xs (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_iq3_s (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_tq1_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_tq2_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + size_t quantize_q2_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q3_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); @@ -127,10 +142,6 @@ void iq2xs_free_impl(enum lm_ggml_type type); void iq3xs_init_impl(int grid_size); void iq3xs_free_impl(int grid_size); -#if defined(__ARM_FEATURE_SVE) -extern int lm_ggml_sve_cnt_b; -#endif - #ifdef __cplusplus } #endif diff --git a/cpp/ggml.c b/cpp/ggml.c index 1c27476..2f2d767 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -1,7 +1,9 @@ #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows #define _USE_MATH_DEFINES // For M_PI on MSVC +#include "ggml-backend.h" #include "ggml-impl.h" +#include "ggml-cpu-impl.h" #include "ggml-quants.h" #include "ggml.h" #include "ggml-aarch64.h" @@ -33,13 +35,6 @@ #include #endif -#ifdef LM_GGML_USE_METAL -#include -#endif - -#if defined(__ARM_FEATURE_SVE) -int lm_ggml_sve_cnt_b = 0; -#endif #if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8) #undef LM_GGML_USE_LLAMAFILE #endif @@ -61,6 +56,25 @@ int lm_ggml_sve_cnt_b = 0; #pragma warning(disable: 4702) #endif +// Note: once we move threading into a separate C++ file +// will use std::hardware_destructive_interference_size instead of hardcoding it here +// and we'll use C++ attribute syntax. +#define LM_GGML_CACHE_LINE 64 + +#if defined(__clang__) || defined(__GNUC__) +#define LM_GGML_CACHE_ALIGN __attribute__((aligned(LM_GGML_CACHE_LINE))) +#endif + +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) +#define LM_GGML_TSAN_ENABLED 1 +#endif +#else // __has_feature +#if defined(__SANITIZE_THREAD__) +#define LM_GGML_TSAN_ENABLED 1 +#endif +#endif // __has_feature + #if defined(_WIN32) #define WIN32_LEAN_AND_MEAN @@ -69,23 +83,44 @@ int lm_ggml_sve_cnt_b = 0; #endif #include +#if !defined(__clang__) +#define LM_GGML_CACHE_ALIGN __declspec(align(LM_GGML_CACHE_LINE)) + typedef volatile LONG atomic_int; typedef atomic_int atomic_bool; typedef atomic_int atomic_flag; #define ATOMIC_FLAG_INIT 0 +typedef enum { + memory_order_relaxed, + memory_order_consume, + memory_order_acquire, + memory_order_release, + memory_order_acq_rel, + memory_order_seq_cst +} memory_order; + static void atomic_store(atomic_int * ptr, LONG val) { InterlockedExchange(ptr, val); } +static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) { + // TODO: add support for explicit memory order + InterlockedExchange(ptr, val); +} static LONG atomic_load(atomic_int * ptr) { return InterlockedCompareExchange(ptr, 0, 0); } +static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) { + // TODO: add support for explicit memory order + return InterlockedCompareExchange(ptr, 0, 0); +} static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { return InterlockedExchangeAdd(ptr, inc); } -static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) { - return atomic_fetch_add(ptr, -(dec)); +static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) { + // TODO: add support for explicit memory order + return InterlockedExchangeAdd(ptr, inc); } static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) { return InterlockedExchange(ptr, 1); @@ -93,6 +128,12 @@ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) { static void atomic_flag_clear(atomic_flag * ptr) { InterlockedExchange(ptr, 0); } +static void atomic_thread_fence(memory_order mo) { + MemoryBarrier(); +} +#else // clang +#include +#endif typedef HANDLE pthread_t; @@ -121,8 +162,13 @@ static int sched_yield (void) { return 0; } #else + #include #include +#include +#if defined(__FreeBSD__) +#include +#endif typedef void * thread_ret_t; @@ -139,6 +185,8 @@ typedef pthread_t lm_ggml_thread_t; #endif #if defined(__APPLE__) +#include +#include #include #endif @@ -258,6 +306,7 @@ void lm_ggml_abort(const char * file, int line, const char * fmt, ...) { } #define LM_GGML_DEBUG 0 + #define LM_GGML_GELU_FP16 #define LM_GGML_GELU_QUICK_FP16 @@ -269,26 +318,64 @@ void lm_ggml_abort(const char * file, int line, const char * fmt, ...) { // logging // +struct lm_ggml_logger_state { + lm_ggml_log_callback log_callback; + void * log_callback_user_data; +}; +static struct lm_ggml_logger_state g_logger_state = {lm_ggml_log_callback_default, NULL}; + +static void lm_ggml_log_internal_v(enum lm_ggml_log_level level, const char * format, va_list args) { + if (format == NULL) { + return; + } + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data); + } else { + char * buffer2 = (char *) calloc(len + 1, sizeof(char)); + vsnprintf(buffer2, len + 1, format, args_copy); + buffer2[len] = 0; + g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data); + free(buffer2); + } + va_end(args_copy); +} + +void lm_ggml_log_internal(enum lm_ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + lm_ggml_log_internal_v(level, format, args); + va_end(args); +} + +void lm_ggml_log_callback_default(enum lm_ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + #if (LM_GGML_DEBUG >= 1) -#define LM_GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#define LM_GGML_PRINT_DEBUG(...) LM_GGML_LOG_DEBUG(__VA_ARGS__) #else #define LM_GGML_PRINT_DEBUG(...) #endif #if (LM_GGML_DEBUG >= 5) -#define LM_GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#define LM_GGML_PRINT_DEBUG_5(...) LM_GGML_LOG_DEBUG(__VA_ARGS__) #else #define LM_GGML_PRINT_DEBUG_5(...) #endif #if (LM_GGML_DEBUG >= 10) -#define LM_GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#define LM_GGML_PRINT_DEBUG_10(...) LM_GGML_LOG_DEBUG(__VA_ARGS__) #else #define LM_GGML_PRINT_DEBUG_10(...) #endif -#define LM_GGML_PRINT(...) printf(__VA_ARGS__) - // // end of logging block // @@ -299,22 +386,40 @@ void lm_ggml_abort(const char * file, int line, const char * fmt, ...) { //#define LM_GGML_SOFT_MAX_ACCELERATE #endif + +void * lm_ggml_aligned_malloc(size_t size) { #if defined(_MSC_VER) || defined(__MINGW32__) -#define LM_GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, LM_GGML_MEM_ALIGN) -#define LM_GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) + return _aligned_malloc(size, TENSOR_ALIGNMENT); #else -inline static void * lm_ggml_aligned_malloc(size_t size) { if (size == 0) { - LM_GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for lm_ggml_aligned_malloc!\n"); + LM_GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for lm_ggml_aligned_malloc!\n"); return NULL; } void * aligned_memory = NULL; #ifdef LM_GGML_USE_CPU_HBM - int result = hbw_posix_memalign(&aligned_memory, 16, size); + int result = hbw_posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size); +#elif TARGET_OS_OSX + kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE); + int result = EFAULT; + switch (alloc_status) { + case KERN_SUCCESS: + result = 0; + break; + case KERN_INVALID_ADDRESS: + result = EINVAL; + break; + case KERN_NO_SPACE: + result = ENOMEM; + break; + default: + result = EFAULT; + break; + } #elif LM_GGML_USE_METAL - int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size); + const long page_size = sysconf(_SC_PAGESIZE); + int result = posix_memalign(&aligned_memory, MAX(TENSOR_ALIGNMENT, page_size), size); #else - int result = posix_memalign(&aligned_memory, LM_GGML_MEM_ALIGN, size); + int result = posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size); #endif if (result != 0) { // Handle allocation failure @@ -327,28 +432,40 @@ inline static void * lm_ggml_aligned_malloc(size_t size) { error_desc = "insufficient memory"; break; } - LM_GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); + LM_GGML_LOG_ERROR("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); LM_GGML_ABORT("fatal error"); return NULL; } return aligned_memory; +#endif } -#define LM_GGML_ALIGNED_MALLOC(size) lm_ggml_aligned_malloc(size) -#ifdef LM_GGML_USE_CPU_HBM -#define LM_GGML_ALIGNED_FREE(ptr) if(NULL != ptr) hbw_free(ptr) + +void lm_ggml_aligned_free(void * ptr, size_t size) { + LM_GGML_UNUSED(size); +#if defined(_MSC_VER) || defined(__MINGW32__) + _aligned_free(ptr); +#elif LM_GGML_USE_CPU_HBM + if (ptr != NULL) { + hbw_free(ptr); + } +#elif TARGET_OS_OSX + if (ptr != NULL) { + vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size); + } #else -#define LM_GGML_ALIGNED_FREE(ptr) free(ptr) -#endif + free(ptr); #endif +} + inline static void * lm_ggml_malloc(size_t size) { if (size == 0) { - LM_GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for lm_ggml_malloc!\n"); + LM_GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for lm_ggml_malloc!\n"); return NULL; } void * result = malloc(size); if (result == NULL) { - LM_GGML_PRINT("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); + LM_GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); LM_GGML_ABORT("fatal error"); } return result; @@ -357,12 +474,12 @@ inline static void * lm_ggml_malloc(size_t size) { // calloc inline static void * lm_ggml_calloc(size_t num, size_t size) { if (num == 0 || size == 0) { - LM_GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for lm_ggml_calloc!\n"); + LM_GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for lm_ggml_calloc!\n"); return NULL; } void * result = calloc(num, size); if (result == NULL) { - LM_GGML_PRINT("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); + LM_GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); LM_GGML_ABORT("fatal error"); } return result; @@ -402,7 +519,16 @@ static lm_ggml_fp16_t lm_ggml_table_gelu_quick_f16[1 << 16]; // precomputed f32 table for f16 (256 KB) (ggml-impl.h) float lm_ggml_table_f32_f16[1 << 16]; -LM_GGML_CALL const char * lm_ggml_status_to_string(enum lm_ggml_status status) { +#if defined(__ARM_ARCH) +struct lm_ggml_arm_arch_features_type { + int has_neon; + int has_i8mm; + int has_sve; + int sve_cnt; +} lm_ggml_arm_arch_features = {-1, -1, -1, 0}; +#endif + +const char * lm_ggml_status_to_string(enum lm_ggml_status status) { switch (status) { case LM_GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)"; case LM_GGML_STATUS_FAILED: return "GGML status: error (operation failed)"; @@ -633,7 +759,7 @@ static void lm_ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const floa static void lm_ggml_vec_dot_f16(int n, float * restrict s, size_t bs, lm_ggml_fp16_t * restrict x, size_t bx, lm_ggml_fp16_t * restrict y, size_t by, int nrc); static void lm_ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, lm_ggml_bf16_t * restrict x, size_t bx, lm_ggml_bf16_t * restrict y, size_t by, int nrc); -static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { +static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { [LM_GGML_TYPE_I8] = { .type_name = "i8", .blck_size = 1, @@ -1027,13 +1153,37 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .ncols = 8, .gemv = lm_ggml_gemv_q4_0_8x8_q8_0, .gemm = lm_ggml_gemm_q4_0_8x8_q8_0, - } + }, + [LM_GGML_TYPE_TQ1_0] = { + .type_name = "tq1_0", + .blck_size = QK_K, + .type_size = sizeof(block_tq1_0), + .is_quantized = true, + .to_float = (lm_ggml_to_float_t) dequantize_row_tq1_0, + .from_float = quantize_row_tq1_0, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq1_0_ref, + .vec_dot = lm_ggml_vec_dot_tq1_0_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_TQ2_0] = { + .type_name = "tq2_0", + .blck_size = QK_K, + .type_size = sizeof(block_tq2_0), + .is_quantized = true, + .to_float = (lm_ggml_to_float_t) dequantize_row_tq2_0, + .from_float = quantize_row_tq2_0, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq2_0_ref, + .vec_dot = lm_ggml_vec_dot_tq2_0_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, }; // For internal test use -lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type) { +const struct lm_ggml_type_traits * lm_ggml_get_type_traits(enum lm_ggml_type type) { LM_GGML_ASSERT(type < LM_GGML_TYPE_COUNT); - return type_traits[type]; + return &type_traits[type]; } // @@ -1069,21 +1219,21 @@ lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type) { #define LM_GGML_F32x4_ADD vaddq_f32 #define LM_GGML_F32x4_MUL vmulq_f32 #define LM_GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) -#define LM_GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f32(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f32(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f32(x[i], x[offset+i]); \ - } \ - res = LM_GGML_F32x4_REDUCE_ONE(x[0]); \ +#define LM_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + (res) = LM_GGML_F32x4_REDUCE_ONE((x)[0]); \ } #define LM_GGML_F32_VEC LM_GGML_F32x4 @@ -1110,30 +1260,30 @@ lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type) { #define LM_GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) #define LM_GGML_F16x8_ADD vaddq_f16 #define LM_GGML_F16x8_MUL vmulq_f16 - #define LM_GGML_F16x8_REDUCE(res, x) \ - do { \ - int offset = LM_GGML_F16_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f16(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f16(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f16(x[i], x[offset+i]); \ - } \ - const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ - const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ - res = (lm_ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ + #define LM_GGML_F16x8_REDUCE(res, x) \ + do { \ + int offset = LM_GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \ + (res) = (lm_ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ } while (0) #define LM_GGML_F16_VEC LM_GGML_F16x8 #define LM_GGML_F16_VEC_ZERO LM_GGML_F16x8_ZERO #define LM_GGML_F16_VEC_SET1 LM_GGML_F16x8_SET1 #define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F16x8_LOAD(p) - #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x8_STORE((lm_ggml_fp16_internal_t *)(p), r[i]) + #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x8_STORE((lm_ggml_fp16_internal_t *)(p), (r)[i]) #define LM_GGML_F16_VEC_FMA LM_GGML_F16x8_FMA #define LM_GGML_F16_VEC_ADD LM_GGML_F16x8_ADD #define LM_GGML_F16_VEC_MUL LM_GGML_F16x8_MUL @@ -1842,24 +1992,37 @@ static inline void __lsx_f16x4_store(lm_ggml_fp16_t * x, __m128 y) { #define LM_GGML_F16_ARR (LM_GGML_F16_STEP/LM_GGML_F16_EPR) #endif +// +// ggml object +// + +struct lm_ggml_object { + size_t offs; + size_t size; + + struct lm_ggml_object * next; + + enum lm_ggml_object_type type; + + char padding[4]; +}; + +static const size_t LM_GGML_OBJECT_SIZE = sizeof(struct lm_ggml_object); + // // ggml context // struct lm_ggml_context { size_t mem_size; - void* mem_buffer; + void * mem_buffer; bool mem_buffer_owned; bool no_alloc; - bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers int n_objects; struct lm_ggml_object * objects_begin; struct lm_ggml_object * objects_end; - - struct lm_ggml_scratch scratch; - struct lm_ggml_scratch scratch_save; }; struct lm_ggml_context_container { @@ -1868,28 +2031,103 @@ struct lm_ggml_context_container { struct lm_ggml_context context; }; -struct lm_ggml_compute_state_shared { - const struct lm_ggml_cgraph * cgraph; - const struct lm_ggml_cplan * cplan; +// +// Threading defs +// + +typedef pthread_t lm_ggml_thread_t; + +#if defined(_WIN32) + +typedef CONDITION_VARIABLE lm_ggml_cond_t; +typedef SRWLOCK lm_ggml_mutex_t; + +#define lm_ggml_mutex_init(m) InitializeSRWLock(m) +#define lm_ggml_mutex_destroy(m) +#define lm_ggml_mutex_lock(m) AcquireSRWLockExclusive(m) +#define lm_ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m) +#define lm_ggml_mutex_lock_shared(m) AcquireSRWLockShared(m) +#define lm_ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m) + +#define lm_ggml_cond_init(c) InitializeConditionVariable(c) +#define lm_ggml_cond_destroy(c) +#define lm_ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED) +#define lm_ggml_cond_broadcast(c) WakeAllConditionVariable(c) + +#define lm_ggml_thread_create pthread_create +#define lm_ggml_thread_join pthread_join + +#else + +typedef pthread_cond_t lm_ggml_cond_t; +typedef pthread_mutex_t lm_ggml_mutex_t; + +#define lm_ggml_mutex_init(m) pthread_mutex_init(m, NULL) +#define lm_ggml_mutex_destroy(m) pthread_mutex_destroy(m) +#define lm_ggml_mutex_lock(m) pthread_mutex_lock(m) +#define lm_ggml_mutex_unlock(m) pthread_mutex_unlock(m) +#define lm_ggml_mutex_lock_shared(m) pthread_mutex_lock(m) +#define lm_ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m) + +#define lm_ggml_lock_init(x) UNUSED(x) +#define lm_ggml_lock_destroy(x) UNUSED(x) +#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) +#define lm_ggml_lock_lock(x) _mm_pause() +#else +#define lm_ggml_lock_lock(x) UNUSED(x) +#endif +#define lm_ggml_lock_unlock(x) UNUSED(x) + +#define LM_GGML_LOCK_INITIALIZER 0 +#define lm_ggml_cond_init(c) pthread_cond_init(c, NULL) +#define lm_ggml_cond_destroy(c) pthread_cond_destroy(c) +#define lm_ggml_cond_wait(c, m) pthread_cond_wait(c, m) +#define lm_ggml_cond_broadcast(c) pthread_cond_broadcast(c) + +#define lm_ggml_thread_create pthread_create +#define lm_ggml_thread_join pthread_join - int n_threads; +#endif + +// Threadpool def +struct lm_ggml_threadpool { + lm_ggml_mutex_t mutex; // mutex for cond.var + lm_ggml_cond_t cond; // cond.var for waiting for new work + + struct lm_ggml_cgraph * cgraph; + struct lm_ggml_cplan * cplan; // synchronization primitives - atomic_int n_barrier; - atomic_int n_barrier_passed; + atomic_int n_graph; // incremented when there is work to be done (i.e each graph) + atomic_int LM_GGML_CACHE_ALIGN n_barrier; + atomic_int LM_GGML_CACHE_ALIGN n_barrier_passed; + atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. + + // these are atomic as an annotation for thread-sanitizer + atomic_bool stop; // Used for stopping the threadpool altogether + atomic_bool pause; // Used for pausing the threadpool or individual threads + atomic_bool abort; // Used for aborting processing of a graph - lm_ggml_abort_callback abort_callback; // abort lm_ggml_graph_compute when true - void * abort_callback_data; + struct lm_ggml_compute_state * workers; // per thread state + int n_threads_max; // number of threads in the pool + atomic_int n_threads_cur; // number of threads used in the current graph - atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads + int32_t prio; // Scheduling priority + uint32_t poll; // Polling level (0 - no polling) enum lm_ggml_status ec; }; +// Per-thread state struct lm_ggml_compute_state { +#ifndef LM_GGML_USE_OPENMP lm_ggml_thread_t thrd; + bool cpumask[LM_GGML_MAX_N_THREADS]; + int last_graph; + bool pending; +#endif + struct lm_ggml_threadpool * threadpool; int ith; - struct lm_ggml_compute_state_shared * shared; }; struct lm_ggml_compute_params { @@ -1900,7 +2138,7 @@ struct lm_ggml_compute_params { size_t wsize; void * wdata; - struct lm_ggml_compute_state_shared * shared; + struct lm_ggml_threadpool * threadpool; }; // @@ -2310,7 +2548,9 @@ inline static void lm_ggml_vec_scale_f16(const int n, lm_ggml_fp16_t * y, const inline static void lm_ggml_vec_norm_f32 (const int n, float * s, const float * x) { lm_ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void lm_ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void lm_ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } -inline static void lm_ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void lm_ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void lm_ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); } +inline static void lm_ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); } inline static void lm_ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } inline static void lm_ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } inline static void lm_ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } @@ -2322,6 +2562,7 @@ inline static void lm_ggml_vec_sigmoid_f32 (const int n, float * y, const float // TODO: optimize performance inline static void lm_ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } inline static void lm_ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } +inline static void lm_ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); } static const float GELU_COEF_A = 0.044715f; static const float GELU_QUICK_COEF = -1.702f; @@ -2669,6 +2910,19 @@ static lm_ggml_float lm_ggml_vec_soft_max_f32(const int n, float * y, const floa return sum; } +static lm_ggml_float lm_ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) { + // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i) + + int i = 0; + lm_ggml_float sum = 0; + for (; i < n; ++i) { + float val = x[i] - max; + y[i] = val; + sum += (lm_ggml_float)expf(val); + } + return sum = (lm_ggml_float)logf(sum); +} + inline static float lm_ggml_silu_backward_f32(float x, float dy) { const float s = 1.0f/(1.0f + expf(-x)); return dy*s*(1.0f + x*(1.0f - s)); @@ -2760,10 +3014,13 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { "SQR", "SQRT", "LOG", + "SIN", + "COS", "SUM", "SUM_ROWS", "MEAN", "ARGMAX", + "COUNT_EQUAL", "REPEAT", "REPEAT_BACK", "CONCAT", @@ -2797,9 +3054,11 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { "CLAMP", "CONV_TRANSPOSE_1D", "IM2COL", + "IM2COL_BACK", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", + "POOL_2D_BACK", "UPSCALE", "PAD", "ARANGE", @@ -2815,6 +3074,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { "WIN_UNPART", "GET_REL_POS", "ADD_REL_POS", + "RWKV_WKV", "UNARY", @@ -2831,9 +3091,10 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", + "OPT_STEP_ADAMW", }; -static_assert(LM_GGML_OP_COUNT == 74, "LM_GGML_OP_COUNT != 74"); +static_assert(LM_GGML_OP_COUNT == 81, "LM_GGML_OP_COUNT != 81"); static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "none", @@ -2848,10 +3109,13 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "x^2", "√x", "log(x)", + "sin(x)", + "cos(x)", "Σx", "Σx_k", "Σx/n", "argmax(x)", + "count_equal(x)", "repeat(x)", "repeat_back(x)", "concat(x, y)", @@ -2885,9 +3149,11 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "clamp(x)", "conv_transpose_1d(x)", "im2col(x)", + "im2col_back(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", + "pool_2d_back(x)", "upscale(x)", "pad(x)", "arange(start, stop, step)", @@ -2903,6 +3169,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "win_unpart(x)", "get_rel_pos(x)", "add_rel_pos(x)", + "rwkv_wkv(k, v, r, tf, td, s)", "unary(x)", @@ -2919,9 +3186,10 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", + "adamw(x)", }; -static_assert(LM_GGML_OP_COUNT == 74, "LM_GGML_OP_COUNT != 74"); +static_assert(LM_GGML_OP_COUNT == 81, "LM_GGML_OP_COUNT != 81"); static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2"); @@ -2940,14 +3208,28 @@ static const char * LM_GGML_UNARY_OP_NAME[LM_GGML_UNARY_OP_COUNT] = { "SILU", "HARDSWISH", "HARDSIGMOID", + "EXP", }; -static_assert(LM_GGML_UNARY_OP_COUNT == 13, "LM_GGML_UNARY_OP_COUNT != 13"); +static_assert(LM_GGML_UNARY_OP_COUNT == 14, "LM_GGML_UNARY_OP_COUNT != 14"); static_assert(sizeof(struct lm_ggml_object)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_object size must be a multiple of LM_GGML_MEM_ALIGN"); static_assert(sizeof(struct lm_ggml_tensor)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_tensor size must be a multiple of LM_GGML_MEM_ALIGN"); +// Helpers for polling loops +#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) ) +static inline void lm_ggml_thread_cpu_relax(void) { + __asm__ volatile("yield" ::: "memory"); +} +#elif defined(__x86_64__) +static inline void lm_ggml_thread_cpu_relax(void) { + _mm_pause(); +} +#else +static inline void lm_ggml_thread_cpu_relax(void) {;} +#endif + // // NUMA support // @@ -2978,7 +3260,6 @@ struct lm_ggml_numa_nodes { // struct lm_ggml_state { - struct lm_ggml_context_container contexts[LM_GGML_MAX_CONTEXTS]; struct lm_ggml_numa_nodes numa; }; @@ -2994,47 +3275,43 @@ inline static void lm_ggml_critical_section_start(void) { } } -#ifdef LM_GGML_USE_OPENMP -static void lm_ggml_barrier(struct lm_ggml_compute_state_shared * shared) { - if (shared->n_threads == 1) { +static void lm_ggml_barrier(struct lm_ggml_threadpool * tp) { + int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed); + if (n_threads == 1) { return; } +#ifdef LM_GGML_USE_OPENMP #pragma omp barrier -} #else -static void lm_ggml_barrier(struct lm_ggml_compute_state_shared * shared) { - if (shared->n_threads == 1) { - return; - } - - atomic_int * n_barrier = &shared->n_barrier; - atomic_int * n_barrier_passed = &shared->n_barrier_passed; + int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed); - int n_threads = shared->n_threads; - int passed_old = atomic_load(n_barrier_passed); + // enter barrier (full seq-cst fence) + int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst); - if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) { + if (n_barrier == (n_threads - 1)) { // last thread - atomic_store(n_barrier, 0); - atomic_fetch_add(n_barrier_passed, 1); - } else { - // wait for other threads - const int n_spin_before_sleep = 100000; - while (true) { - for (int i = 0; i < n_spin_before_sleep; i++) { - if (atomic_load(n_barrier_passed) != passed_old) { - return; - } - #if defined(__SSE3__) - _mm_pause(); - #endif - } - sched_yield(); - } + atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed); + + // exit barrier (fill seq-cst fence) + atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst); + return; } -} + + // wait for other threads + while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) { + lm_ggml_thread_cpu_relax(); + } + + // exit barrier (full seq-cst fence) + // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead + #ifdef LM_GGML_TSAN_ENABLED + atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst); + #else + atomic_thread_fence(memory_order_seq_cst); + #endif #endif +} // TODO: make this somehow automatically executed // some sort of "sentry" mechanism @@ -3134,7 +3411,7 @@ void lm_ggml_numa_init(enum lm_ggml_numa_strategy numa_flag) { if (fptr != NULL) { char buf[42]; if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) { - LM_GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); + LM_GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); } fclose(fptr); } @@ -3152,38 +3429,38 @@ bool lm_ggml_is_numa(void) { //////////////////////////////////////////////////////////////////////////////// void lm_ggml_print_object(const struct lm_ggml_object * obj) { - LM_GGML_PRINT(" - lm_ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", + LM_GGML_LOG_INFO(" - lm_ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", obj->type, obj->offs, obj->size, (const void *) obj->next); } void lm_ggml_print_objects(const struct lm_ggml_context * ctx) { struct lm_ggml_object * obj = ctx->objects_begin; - LM_GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); + LM_GGML_LOG_INFO("%s: objects in context %p:\n", __func__, (const void *) ctx); while (obj != NULL) { lm_ggml_print_object(obj); obj = obj->next; } - LM_GGML_PRINT("%s: --- end ---\n", __func__); + LM_GGML_LOG_INFO("%s: --- end ---\n", __func__); } -LM_GGML_CALL int64_t lm_ggml_nelements(const struct lm_ggml_tensor * tensor) { +int64_t lm_ggml_nelements(const struct lm_ggml_tensor * tensor) { static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; } -LM_GGML_CALL int64_t lm_ggml_nrows(const struct lm_ggml_tensor * tensor) { +int64_t lm_ggml_nrows(const struct lm_ggml_tensor * tensor) { static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; } -LM_GGML_CALL size_t lm_ggml_nbytes(const struct lm_ggml_tensor * tensor) { +size_t lm_ggml_nbytes(const struct lm_ggml_tensor * tensor) { size_t nbytes; - size_t blck_size = lm_ggml_blck_size(tensor->type); + const size_t blck_size = lm_ggml_blck_size(tensor->type); if (blck_size == 1) { nbytes = lm_ggml_type_size(tensor->type); for (int i = 0; i < LM_GGML_MAX_DIMS; ++i) { @@ -3204,15 +3481,15 @@ size_t lm_ggml_nbytes_pad(const struct lm_ggml_tensor * tensor) { return LM_GGML_PAD(lm_ggml_nbytes(tensor), LM_GGML_MEM_ALIGN); } -LM_GGML_CALL int64_t lm_ggml_blck_size(enum lm_ggml_type type) { +int64_t lm_ggml_blck_size(enum lm_ggml_type type) { return type_traits[type].blck_size; } -LM_GGML_CALL size_t lm_ggml_type_size(enum lm_ggml_type type) { +size_t lm_ggml_type_size(enum lm_ggml_type type) { return type_traits[type].type_size; } -LM_GGML_CALL size_t lm_ggml_row_size(enum lm_ggml_type type, int64_t ne) { +size_t lm_ggml_row_size(enum lm_ggml_type type, int64_t ne) { assert(ne % lm_ggml_blck_size(type) == 0); return lm_ggml_type_size(type)*ne/lm_ggml_blck_size(type); } @@ -3221,15 +3498,15 @@ double lm_ggml_type_sizef(enum lm_ggml_type type) { return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; } -LM_GGML_CALL const char * lm_ggml_type_name(enum lm_ggml_type type) { - return type_traits[type].type_name; +const char * lm_ggml_type_name(enum lm_ggml_type type) { + return type < LM_GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE"; } -LM_GGML_CALL bool lm_ggml_is_quantized(enum lm_ggml_type type) { +bool lm_ggml_is_quantized(enum lm_ggml_type type) { return type_traits[type].is_quantized; } -LM_GGML_CALL const char * lm_ggml_op_name(enum lm_ggml_op op) { +const char * lm_ggml_op_name(enum lm_ggml_op op) { return LM_GGML_OP_NAME[op]; } @@ -3241,7 +3518,7 @@ const char * lm_ggml_unary_op_name(enum lm_ggml_unary_op op) { return LM_GGML_UNARY_OP_NAME[op]; } -LM_GGML_CALL const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t) { +const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t) { if (t->op == LM_GGML_OP_UNARY) { enum lm_ggml_unary_op uop = lm_ggml_get_unary_op(t); return lm_ggml_unary_op_name(uop); @@ -3249,7 +3526,7 @@ LM_GGML_CALL const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t) { return lm_ggml_op_name(t->op); } -LM_GGML_CALL size_t lm_ggml_element_size(const struct lm_ggml_tensor * tensor) { +size_t lm_ggml_element_size(const struct lm_ggml_tensor * tensor) { return lm_ggml_type_size(tensor->type); } @@ -3342,7 +3619,7 @@ size_t lm_ggml_tensor_overhead(void) { return LM_GGML_OBJECT_SIZE + LM_GGML_TENSOR_SIZE; } -LM_GGML_CALL bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor) { +bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1]; } @@ -3368,23 +3645,23 @@ static bool lm_ggml_is_contiguous_n(const struct lm_ggml_tensor * tensor, int n) return true; } -LM_GGML_CALL bool lm_ggml_is_contiguous(const struct lm_ggml_tensor * tensor) { +bool lm_ggml_is_contiguous(const struct lm_ggml_tensor * tensor) { return lm_ggml_is_contiguous_0(tensor); } -LM_GGML_CALL bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor) { +bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor) { return lm_ggml_is_contiguous_n(tensor, 0); } -LM_GGML_CALL bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor) { +bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor) { return lm_ggml_is_contiguous_n(tensor, 1); } -LM_GGML_CALL bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor) { +bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor) { return lm_ggml_is_contiguous_n(tensor, 2); } -LM_GGML_CALL bool lm_ggml_is_permuted(const struct lm_ggml_tensor * tensor) { +bool lm_ggml_is_permuted(const struct lm_ggml_tensor * tensor) { static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; @@ -3399,7 +3676,7 @@ static inline bool lm_ggml_is_padded_1d(const struct lm_ggml_tensor * tensor) { tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } -LM_GGML_CALL bool lm_ggml_is_empty(const struct lm_ggml_tensor * tensor) { +bool lm_ggml_is_empty(const struct lm_ggml_tensor * tensor) { for (int i = 0; i < LM_GGML_MAX_DIMS; ++i) { if (tensor->ne[i] == 0) { // empty if any dimension has no elements @@ -3466,6 +3743,70 @@ static inline int lm_ggml_up(int n, int m) { //////////////////////////////////////////////////////////////////////////////// +#if defined(__ARM_ARCH) + +#if defined(__linux__) && defined(__aarch64__) +#include +#elif defined(__APPLE__) +#include +#endif + +#if !defined(HWCAP2_I8MM) +#define HWCAP2_I8MM 0 +#endif + +static void lm_ggml_init_arm_arch_features(void) { +#if defined(__linux__) && defined(__aarch64__) + uint32_t hwcap = getauxval(AT_HWCAP); + uint32_t hwcap2 = getauxval(AT_HWCAP2); + + lm_ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); + lm_ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); + lm_ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); + +#if defined(__ARM_FEATURE_SVE) + lm_ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); +#endif +#elif defined(__APPLE__) + int oldp = 0; + size_t size = sizeof(oldp); + if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + lm_ggml_arm_arch_features.has_neon = oldp; + + if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + lm_ggml_arm_arch_features.has_i8mm = oldp; + + lm_ggml_arm_arch_features.has_sve = 0; + lm_ggml_arm_arch_features.sve_cnt = 0; +#else +// Run-time CPU feature detection not implemented for this platform, fallback to compile time +#if defined(__ARM_NEON) + lm_ggml_arm_arch_features.has_neon = 1; +#else + lm_ggml_arm_arch_features.has_neon = 0; +#endif + +#if defined(__ARM_FEATURE_MATMUL_INT8) + lm_ggml_arm_arch_features.has_i8mm = 1; +#else + lm_ggml_arm_arch_features.has_i8mm = 0; +#endif + +#if defined(__ARM_FEATURE_SVE) + lm_ggml_arm_arch_features.has_sve = 1; + lm_ggml_arm_arch_features.sve_cnt = 16; +#else + lm_ggml_arm_arch_features.has_sve = 0; + lm_ggml_arm_arch_features.sve_cnt = 0; +#endif +#endif +} +#endif + struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) { // make this function thread safe lm_ggml_critical_section_start(); @@ -3500,45 +3841,27 @@ struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) { const uint64_t t_start = lm_ggml_time_us(); UNUSED(t_start); g_state = (struct lm_ggml_state) { - /*.contexts =*/ { { 0 } }, /*.numa =*/ { .n_nodes = 0, .total_cpus = 0, }, }; - for (int i = 0; i < LM_GGML_MAX_CONTEXTS; ++i) { - g_state.contexts[i].used = false; - } - const uint64_t t_end = lm_ggml_time_us(); UNUSED(t_end); LM_GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); } - is_first_call = false; - } - - // find non-used context in g_state - struct lm_ggml_context * ctx = NULL; - - for (int i = 0; i < LM_GGML_MAX_CONTEXTS; i++) { - if (!g_state.contexts[i].used) { - g_state.contexts[i].used = true; - ctx = &g_state.contexts[i].context; +#if defined(__ARM_ARCH) + lm_ggml_init_arm_arch_features(); +#endif - LM_GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); - break; - } + is_first_call = false; } - if (ctx == NULL) { - LM_GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); - - lm_ggml_critical_section_end(); + lm_ggml_critical_section_end(); - return NULL; - } + struct lm_ggml_context * ctx = LM_GGML_MALLOC(sizeof(struct lm_ggml_context)); // allow to call lm_ggml_init with 0 size if (params.mem_size == 0) { @@ -3549,79 +3872,49 @@ struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) { *ctx = (struct lm_ggml_context) { /*.mem_size =*/ mem_size, - /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : LM_GGML_ALIGNED_MALLOC(mem_size), + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : lm_ggml_aligned_malloc(mem_size), /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, /*.no_alloc =*/ params.no_alloc, - /*.no_alloc_save =*/ params.no_alloc, /*.n_objects =*/ 0, /*.objects_begin =*/ NULL, /*.objects_end =*/ NULL, - /*.scratch =*/ { 0, 0, NULL, }, - /*.scratch_save =*/ { 0, 0, NULL, }, }; LM_GGML_ASSERT(ctx->mem_buffer != NULL); LM_GGML_ASSERT_ALIGNED(ctx->mem_buffer); -#if defined(__ARM_FEATURE_SVE) - if (!lm_ggml_sve_cnt_b) { - lm_ggml_sve_cnt_b = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); - } -#endif - LM_GGML_PRINT_DEBUG("%s: context initialized\n", __func__); - lm_ggml_critical_section_end(); - return ctx; } -void lm_ggml_free(struct lm_ggml_context * ctx) { +void lm_ggml_reset(struct lm_ggml_context * ctx) { if (ctx == NULL) { return; } - // make this function thread safe - lm_ggml_critical_section_start(); - - bool found = false; - - for (int i = 0; i < LM_GGML_MAX_CONTEXTS; i++) { - if (&g_state.contexts[i].context == ctx) { - g_state.contexts[i].used = false; - - LM_GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n", - __func__, i, lm_ggml_used_mem(ctx)); - - if (ctx->mem_buffer_owned) { - LM_GGML_ALIGNED_FREE(ctx->mem_buffer); - } + ctx->n_objects = 0; + ctx->objects_begin = NULL; + ctx->objects_end = NULL; +} - found = true; - break; - } +void lm_ggml_free(struct lm_ggml_context * ctx) { + if (ctx == NULL) { + return; } - if (!found) { - LM_GGML_PRINT_DEBUG("%s: context not found\n", __func__); + if (ctx->mem_buffer_owned) { + lm_ggml_aligned_free(ctx->mem_buffer, ctx->mem_size); } - lm_ggml_critical_section_end(); + LM_GGML_FREE(ctx); } size_t lm_ggml_used_mem(const struct lm_ggml_context * ctx) { return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size; } -size_t lm_ggml_set_scratch(struct lm_ggml_context * ctx, struct lm_ggml_scratch scratch) { - const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0; - - ctx->scratch = scratch; - - return result; -} - bool lm_ggml_get_no_alloc(struct lm_ggml_context * ctx) { return ctx->no_alloc; } @@ -3649,27 +3942,6 @@ size_t lm_ggml_get_max_tensor_size(const struct lm_ggml_context * ctx) { return max_size; } -// IMPORTANT: -// when creating "opt" tensors, always save and load the scratch buffer -// this is an error prone process, but it is necessary to support inplace -// operators when using scratch buffers -// TODO: implement a better way -static void lm_ggml_scratch_save(struct lm_ggml_context * ctx) { - // this is needed to allow opt tensors to store their data - // TODO: again, need to find a better way - ctx->no_alloc_save = ctx->no_alloc; - ctx->no_alloc = false; - - ctx->scratch_save = ctx->scratch; - ctx->scratch.data = NULL; -} - -static void lm_ggml_scratch_load(struct lm_ggml_context * ctx) { - ctx->no_alloc = ctx->no_alloc_save; - - ctx->scratch = ctx->scratch_save; -} - //////////////////////////////////////////////////////////////////////////////// static struct lm_ggml_object * lm_ggml_new_object(struct lm_ggml_context * ctx, enum lm_ggml_object_type type, size_t size) { @@ -3687,9 +3959,11 @@ static struct lm_ggml_object * lm_ggml_new_object(struct lm_ggml_context * ctx, struct lm_ggml_object * const obj_new = (struct lm_ggml_object *)(mem_buffer + cur_end); if (cur_end + size_needed + LM_GGML_OBJECT_SIZE > ctx->mem_size) { - LM_GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + size_needed, ctx->mem_size); - assert(false); + LM_GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed + LM_GGML_OBJECT_SIZE, ctx->mem_size); +#ifndef NDEBUG + LM_GGML_ABORT("not enough space in the context's memory pool"); +#endif return NULL; } @@ -3748,27 +4022,12 @@ static struct lm_ggml_tensor * lm_ggml_new_tensor_impl( size_t obj_alloc_size = 0; if (view_src == NULL && !ctx->no_alloc) { - if (ctx->scratch.data != NULL) { - // allocate tensor data in the scratch buffer - if (ctx->scratch.offs + data_size > ctx->scratch.size) { - LM_GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n", - __func__, ctx->scratch.offs + data_size, ctx->scratch.size); - assert(false); - return NULL; - } - - data = (char * const) ctx->scratch.data + ctx->scratch.offs; - - ctx->scratch.offs += data_size; - } else { - // allocate tensor data in the context's memory pool - obj_alloc_size = data_size; - } + // allocate tensor data in the context's memory pool + obj_alloc_size = data_size; } struct lm_ggml_object * const obj_new = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_TENSOR, LM_GGML_TENSOR_SIZE + obj_alloc_size); - - // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here + LM_GGML_ASSERT(obj_new); struct lm_ggml_tensor * const result = (struct lm_ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); @@ -3865,24 +4124,16 @@ struct lm_ggml_tensor * lm_ggml_new_tensor_4d( } struct lm_ggml_tensor * lm_ggml_new_i32(struct lm_ggml_context * ctx, int32_t value) { - lm_ggml_scratch_save(ctx); - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, 1); - lm_ggml_scratch_load(ctx); - lm_ggml_set_i32(result, value); return result; } struct lm_ggml_tensor * lm_ggml_new_f32(struct lm_ggml_context * ctx, float value) { - lm_ggml_scratch_save(ctx); - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, 1); - lm_ggml_scratch_load(ctx); - lm_ggml_set_f32(result, value); return result; @@ -3919,7 +4170,15 @@ static void lm_ggml_set_op_params_f32(struct lm_ggml_tensor * tensor, uint32_t i } struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor) { - memset(tensor->data, 0, lm_ggml_nbytes(tensor)); + if (lm_ggml_is_empty(tensor)) { + return tensor; + } + if (tensor->buffer) { + lm_ggml_backend_tensor_memset(tensor, 0, 0, lm_ggml_nbytes(tensor)); + } else { + LM_GGML_ASSERT(tensor->data); + memset(tensor->data, 0, lm_ggml_nbytes(tensor)); + } return tensor; } @@ -4348,7 +4607,7 @@ float * lm_ggml_get_data_f32(const struct lm_ggml_tensor * tensor) { return (float *)(tensor->data); } -LM_GGML_CALL enum lm_ggml_unary_op lm_ggml_get_unary_op(const struct lm_ggml_tensor * tensor) { +enum lm_ggml_unary_op lm_ggml_get_unary_op(const struct lm_ggml_tensor * tensor) { LM_GGML_ASSERT(tensor->op == LM_GGML_OP_UNARY); return (enum lm_ggml_unary_op) lm_ggml_get_op_params_i32(tensor, 0); } @@ -4445,18 +4704,11 @@ struct lm_ggml_tensor * lm_ggml_get_tensor(struct lm_ggml_context * ctx, const c static struct lm_ggml_tensor * lm_ggml_dup_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + struct lm_ggml_tensor * a, + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_DUP; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_DUP; result->src[0] = a; return result; @@ -4464,13 +4716,13 @@ static struct lm_ggml_tensor * lm_ggml_dup_impl( struct lm_ggml_tensor * lm_ggml_dup( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { + struct lm_ggml_tensor * a) { return lm_ggml_dup_impl(ctx, a, false); } struct lm_ggml_tensor * lm_ggml_dup_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { + struct lm_ggml_tensor * a) { return lm_ggml_dup_impl(ctx, a, true); } @@ -4478,23 +4730,14 @@ struct lm_ggml_tensor * lm_ggml_dup_inplace( static struct lm_ggml_tensor * lm_ggml_add_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - bool inplace) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + bool inplace) { LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - // TODO: support backward pass for broadcasting - LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_ADD; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_ADD; result->src[0] = a; result->src[1] = b; @@ -4503,15 +4746,15 @@ static struct lm_ggml_tensor * lm_ggml_add_impl( struct lm_ggml_tensor * lm_ggml_add( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { return lm_ggml_add_impl(ctx, a, b, false); } struct lm_ggml_tensor * lm_ggml_add_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { return lm_ggml_add_impl(ctx, a, b, true); } @@ -4519,9 +4762,9 @@ struct lm_ggml_tensor * lm_ggml_add_inplace( static struct lm_ggml_tensor * lm_ggml_add_cast_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - enum lm_ggml_type type) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + enum lm_ggml_type type) { // TODO: support less-strict constraint // LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); LM_GGML_ASSERT(lm_ggml_can_repeat_rows(b, a)); @@ -4531,18 +4774,9 @@ static struct lm_ggml_tensor * lm_ggml_add_cast_impl( a->type == LM_GGML_TYPE_F16 || a->type == LM_GGML_TYPE_BF16); - bool is_node = false; - - if (a->grad || b->grad) { - // TODO: support backward pass for broadcasting - LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - is_node = true; - } - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, type, LM_GGML_MAX_DIMS, a->ne); - result->op = LM_GGML_OP_ADD; - result->grad = is_node ? lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, a->ne) : NULL; + result->op = LM_GGML_OP_ADD; result->src[0] = a; result->src[1] = b; @@ -4551,9 +4785,9 @@ static struct lm_ggml_tensor * lm_ggml_add_cast_impl( struct lm_ggml_tensor * lm_ggml_add_cast( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - enum lm_ggml_type type) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + enum lm_ggml_type type) { return lm_ggml_add_cast_impl(ctx, a, b, type); } @@ -4561,22 +4795,15 @@ struct lm_ggml_tensor * lm_ggml_add_cast( static struct lm_ggml_tensor * lm_ggml_add1_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - bool inplace) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + bool inplace) { LM_GGML_ASSERT(lm_ggml_is_scalar(b)); LM_GGML_ASSERT(lm_ggml_is_padded_1d(a)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_ADD1; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_ADD1; result->src[0] = a; result->src[1] = b; @@ -4585,15 +4812,15 @@ static struct lm_ggml_tensor * lm_ggml_add1_impl( struct lm_ggml_tensor * lm_ggml_add1( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { return lm_ggml_add1_impl(ctx, a, b, false); } struct lm_ggml_tensor * lm_ggml_add1_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { return lm_ggml_add1_impl(ctx, a, b, true); } @@ -4601,31 +4828,24 @@ struct lm_ggml_tensor * lm_ggml_add1_inplace( static struct lm_ggml_tensor * lm_ggml_acc_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset, - bool inplace) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { LM_GGML_ASSERT(lm_ggml_nelements(b) <= lm_ggml_nelements(a)); LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32); LM_GGML_ASSERT(b->type == LM_GGML_TYPE_F32); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_ACC; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_ACC; result->src[0] = a; result->src[1] = b; @@ -4634,23 +4854,23 @@ static struct lm_ggml_tensor * lm_ggml_acc_impl( struct lm_ggml_tensor * lm_ggml_acc( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); } struct lm_ggml_tensor * lm_ggml_acc_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); } @@ -4658,21 +4878,14 @@ struct lm_ggml_tensor * lm_ggml_acc_inplace( static struct lm_ggml_tensor * lm_ggml_sub_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - bool inplace) { - LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + bool inplace) { + LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_SUB; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SUB; result->src[0] = a; result->src[1] = b; @@ -4681,15 +4894,15 @@ static struct lm_ggml_tensor * lm_ggml_sub_impl( struct lm_ggml_tensor * lm_ggml_sub( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { return lm_ggml_sub_impl(ctx, a, b, false); } struct lm_ggml_tensor * lm_ggml_sub_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { return lm_ggml_sub_impl(ctx, a, b, true); } @@ -4697,27 +4910,14 @@ struct lm_ggml_tensor * lm_ggml_sub_inplace( static struct lm_ggml_tensor * lm_ggml_mul_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - bool inplace) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + bool inplace) { LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - // TODO: support backward pass for broadcasting - LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - is_node = true; - } - - if (inplace) { - LM_GGML_ASSERT(!is_node); - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_MUL; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MUL; result->src[0] = a; result->src[1] = b; @@ -4742,25 +4942,14 @@ struct lm_ggml_tensor * lm_ggml_mul_inplace( static struct lm_ggml_tensor * lm_ggml_div_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - bool inplace) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + bool inplace) { LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - if (inplace) { - LM_GGML_ASSERT(!is_node); - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_DIV; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_DIV; result->src[0] = a; result->src[1] = b; @@ -4785,18 +4974,11 @@ struct lm_ggml_tensor * lm_ggml_div_inplace( static struct lm_ggml_tensor * lm_ggml_sqr_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + struct lm_ggml_tensor * a, + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_SQR; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SQR; result->src[0] = a; return result; @@ -4818,18 +5000,11 @@ struct lm_ggml_tensor * lm_ggml_sqr_inplace( static struct lm_ggml_tensor * lm_ggml_sqrt_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + struct lm_ggml_tensor * a, + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_SQRT; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SQRT; result->src[0] = a; return result; @@ -4852,17 +5027,10 @@ struct lm_ggml_tensor * lm_ggml_sqrt_inplace( static struct lm_ggml_tensor * lm_ggml_log_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_LOG; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_LOG; result->src[0] = a; return result; @@ -4880,21 +5048,66 @@ struct lm_ggml_tensor * lm_ggml_log_inplace( return lm_ggml_log_impl(ctx, a, true); } -// lm_ggml_sum +// lm_ggml_sin -struct lm_ggml_tensor * lm_ggml_sum( +static struct lm_ggml_tensor * lm_ggml_sin_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - bool is_node = false; + struct lm_ggml_tensor * a, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - if (a->grad) { - is_node = true; - } + result->op = LM_GGML_OP_SIN; + result->src[0] = a; + + return result; +} + +struct lm_ggml_tensor * lm_ggml_sin( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_sin_impl(ctx, a, false); +} + +struct lm_ggml_tensor * lm_ggml_sin_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_sin_impl(ctx, a, true); +} + +// lm_ggml_cos + +static struct lm_ggml_tensor * lm_ggml_cos_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + + result->op = LM_GGML_OP_COS; + result->src[0] = a; + + return result; +} + +struct lm_ggml_tensor * lm_ggml_cos( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_cos_impl(ctx, a, false); +} + +struct lm_ggml_tensor * lm_ggml_cos_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_cos_impl(ctx, a, true); +} + +// lm_ggml_sum +struct lm_ggml_tensor * lm_ggml_sum( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, a->type, 1); - result->op = LM_GGML_OP_SUM; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SUM; result->src[0] = a; return result; @@ -4904,13 +5117,7 @@ struct lm_ggml_tensor * lm_ggml_sum( struct lm_ggml_tensor * lm_ggml_sum_rows( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - + struct lm_ggml_tensor * a) { int64_t ne[LM_GGML_MAX_DIMS] = { 1 }; for (int i = 1; i < LM_GGML_MAX_DIMS; ++i) { ne[i] = a->ne[i]; @@ -4918,8 +5125,7 @@ struct lm_ggml_tensor * lm_ggml_sum_rows( struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, ne); - result->op = LM_GGML_OP_SUM_ROWS; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SUM_ROWS; result->src[0] = a; return result; @@ -4929,19 +5135,11 @@ struct lm_ggml_tensor * lm_ggml_sum_rows( struct lm_ggml_tensor * lm_ggml_mean( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - + struct lm_ggml_tensor * a) { int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] }; struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - result->op = LM_GGML_OP_MEAN; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MEAN; result->src[0] = a; return result; @@ -4951,20 +5149,30 @@ struct lm_ggml_tensor * lm_ggml_mean( struct lm_ggml_tensor * lm_ggml_argmax( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { + struct lm_ggml_tensor * a) { LM_GGML_ASSERT(lm_ggml_is_matrix(a)); - bool is_node = false; - - if (a->grad) { - LM_GGML_ABORT("fatal error"); - is_node = true; - } struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, a->ne[1]); - result->op = LM_GGML_OP_ARGMAX; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_ARGMAX; + result->src[0] = a; + + return result; +} + +// lm_ggml_count_equal + +struct lm_ggml_tensor * lm_ggml_count_equal( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); + + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I64, 1); + + result->op = LM_GGML_OP_COUNT_EQUAL; result->src[0] = a; + result->src[1] = b; return result; } @@ -4973,20 +5181,13 @@ struct lm_ggml_tensor * lm_ggml_argmax( struct lm_ggml_tensor * lm_ggml_repeat( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { LM_GGML_ASSERT(lm_ggml_can_repeat(a, b)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, b->ne); - result->op = LM_GGML_OP_REPEAT; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_REPEAT; result->src[0] = a; return result; @@ -4996,24 +5197,13 @@ struct lm_ggml_tensor * lm_ggml_repeat( struct lm_ggml_tensor * lm_ggml_repeat_back( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - if (lm_ggml_are_same_shape(a, b) && !is_node) { - return a; - } - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, b->ne); - result->op = LM_GGML_OP_REPEAT_BACK; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_REPEAT_BACK; result->src[0] = a; return result; @@ -5023,9 +5213,9 @@ struct lm_ggml_tensor * lm_ggml_repeat_back( struct lm_ggml_tensor * lm_ggml_concat( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int dim) { + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int dim) { LM_GGML_ASSERT(dim >= 0 && dim < LM_GGML_MAX_DIMS); int64_t ne[LM_GGML_MAX_DIMS]; @@ -5038,18 +5228,11 @@ struct lm_ggml_tensor * lm_ggml_concat( ne[d] = a->ne[d]; } - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, ne); lm_ggml_set_op_params_i32(result, 0, dim); - result->op = LM_GGML_OP_CONCAT; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_CONCAT; result->src[0] = a; result->src[1] = b; @@ -5158,19 +5341,14 @@ struct lm_ggml_tensor * lm_ggml_relu_inplace( struct lm_ggml_tensor * lm_ggml_leaky_relu( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, float negative_slope, bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + struct lm_ggml_tensor * a, + float negative_slope, + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); - result->op = LM_GGML_OP_LEAKY_RELU; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_LEAKY_RELU; result->src[0] = a; return result; @@ -5238,17 +5416,9 @@ struct lm_ggml_tensor * lm_ggml_silu_back( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b) { - bool is_node = false; - - if (a->grad || b->grad) { - // TODO: implement backward - is_node = true; - } - struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_SILU_BACK; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SILU_BACK; result->src[0] = a; result->src[1] = b; @@ -5256,6 +5426,7 @@ struct lm_ggml_tensor * lm_ggml_silu_back( } // ggml hardswish + struct lm_ggml_tensor * lm_ggml_hardswish( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a) { @@ -5263,32 +5434,39 @@ struct lm_ggml_tensor * lm_ggml_hardswish( } // ggml hardsigmoid + struct lm_ggml_tensor * lm_ggml_hardsigmoid( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a) { return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_HARDSIGMOID); } +// ggml exp + +struct lm_ggml_tensor * lm_ggml_exp( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_EXP); +} + +struct lm_ggml_tensor * lm_ggml_exp_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_EXP); +} + // lm_ggml_norm static struct lm_ggml_tensor * lm_ggml_norm_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float eps, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + float eps, + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params(result, &eps, sizeof(eps)); - result->op = LM_GGML_OP_NORM; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_NORM; result->src[0] = a; return result; @@ -5297,14 +5475,14 @@ static struct lm_ggml_tensor * lm_ggml_norm_impl( struct lm_ggml_tensor * lm_ggml_norm( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float eps) { + float eps) { return lm_ggml_norm_impl(ctx, a, eps, false); } struct lm_ggml_tensor * lm_ggml_norm_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float eps) { + float eps) { return lm_ggml_norm_impl(ctx, a, eps, true); } @@ -5313,20 +5491,13 @@ struct lm_ggml_tensor * lm_ggml_norm_inplace( static struct lm_ggml_tensor * lm_ggml_rms_norm_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float eps, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + float eps, + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params(result, &eps, sizeof(eps)); - result->op = LM_GGML_OP_RMS_NORM; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_RMS_NORM; result->src[0] = a; return result; @@ -5335,14 +5506,14 @@ static struct lm_ggml_tensor * lm_ggml_rms_norm_impl( struct lm_ggml_tensor * lm_ggml_rms_norm( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float eps) { + float eps) { return lm_ggml_rms_norm_impl(ctx, a, eps, false); } struct lm_ggml_tensor * lm_ggml_rms_norm_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float eps) { + float eps) { return lm_ggml_rms_norm_impl(ctx, a, eps, true); } @@ -5352,20 +5523,12 @@ struct lm_ggml_tensor * lm_ggml_rms_norm_back( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - float eps) { - bool is_node = false; - - if (a->grad) { - // TODO: implement backward - is_node = true; - } - + float eps) { struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params(result, &eps, sizeof(eps)); - result->op = LM_GGML_OP_RMS_NORM_BACK; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_RMS_NORM_BACK; result->src[0] = a; result->src[1] = b; @@ -5375,43 +5538,35 @@ struct lm_ggml_tensor * lm_ggml_rms_norm_back( // lm_ggml_group_norm static struct lm_ggml_tensor * lm_ggml_group_norm_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int n_groups, - float eps, - bool inplace) { - - bool is_node = false; - if (!inplace && (a->grad)) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int n_groups, + float eps, + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params_i32(result, 0, n_groups); lm_ggml_set_op_params_f32(result, 1, eps); - result->op = LM_GGML_OP_GROUP_NORM; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_GROUP_NORM; result->src[0] = a; return result; } struct lm_ggml_tensor * lm_ggml_group_norm( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int n_groups, - float eps) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int n_groups, + float eps) { return lm_ggml_group_norm_impl(ctx, a, n_groups, eps, false); } struct lm_ggml_tensor * lm_ggml_group_norm_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int n_groups, - float eps) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int n_groups, + float eps) { return lm_ggml_group_norm_impl(ctx, a, n_groups, eps, true); } @@ -5424,17 +5579,10 @@ struct lm_ggml_tensor * lm_ggml_mul_mat( LM_GGML_ASSERT(lm_ggml_can_mul_mat(a, b)); LM_GGML_ASSERT(!lm_ggml_is_transposed(a)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - result->op = LM_GGML_OP_MUL_MAT; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MUL_MAT; result->src[0] = a; result->src[1] = b; @@ -5480,17 +5628,10 @@ struct lm_ggml_tensor * lm_ggml_mul_mat_id( LM_GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat LM_GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast - bool is_node = false; - - if (as->grad || b->grad) { - is_node = true; - } - const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 }; struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - result->op = LM_GGML_OP_MUL_MAT_ID; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MUL_MAT_ID; result->src[0] = as; result->src[1] = b; result->src[2] = ids; @@ -5507,18 +5648,11 @@ struct lm_ggml_tensor * lm_ggml_out_prod( LM_GGML_ASSERT(lm_ggml_can_out_prod(a, b)); LM_GGML_ASSERT(!lm_ggml_is_transposed(a)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - result->op = LM_GGML_OP_OUT_PROD; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_OUT_PROD; result->src[0] = a; result->src[1] = b; @@ -5531,21 +5665,14 @@ static struct lm_ggml_tensor * lm_ggml_scale_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, float s, - bool inplace) { + bool inplace) { LM_GGML_ASSERT(lm_ggml_is_padded_1d(a)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params(result, &s, sizeof(s)); - result->op = LM_GGML_OP_SCALE; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SCALE; result->src[0] = a; return result; @@ -5553,15 +5680,15 @@ static struct lm_ggml_tensor * lm_ggml_scale_impl( struct lm_ggml_tensor * lm_ggml_scale( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - float s) { + struct lm_ggml_tensor * a, + float s) { return lm_ggml_scale_impl(ctx, a, s, false); } struct lm_ggml_tensor * lm_ggml_scale_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - float s) { + struct lm_ggml_tensor * a, + float s) { return lm_ggml_scale_impl(ctx, a, s, true); } @@ -5575,23 +5702,17 @@ static struct lm_ggml_tensor * lm_ggml_set_impl( size_t nb2, size_t nb3, size_t offset, - bool inplace) { + bool inplace) { LM_GGML_ASSERT(lm_ggml_nelements(a) >= lm_ggml_nelements(b)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // make a view of the destination struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + LM_GGML_ASSERT(offset < (size_t)(1 << 30)); int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_SET; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SET; result->src[0] = a; result->src[1] = b; @@ -5600,8 +5721,8 @@ static struct lm_ggml_tensor * lm_ggml_set_impl( struct lm_ggml_tensor * lm_ggml_set( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, @@ -5611,8 +5732,8 @@ struct lm_ggml_tensor * lm_ggml_set( struct lm_ggml_tensor * lm_ggml_set_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, @@ -5622,24 +5743,24 @@ struct lm_ggml_tensor * lm_ggml_set_inplace( struct lm_ggml_tensor * lm_ggml_set_1d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, size_t offset) { return lm_ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false); } struct lm_ggml_tensor * lm_ggml_set_1d_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, size_t offset) { return lm_ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true); } struct lm_ggml_tensor * lm_ggml_set_2d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, size_t nb1, size_t offset) { return lm_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); @@ -5647,8 +5768,8 @@ struct lm_ggml_tensor * lm_ggml_set_2d( struct lm_ggml_tensor * lm_ggml_set_2d_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, size_t nb1, size_t offset) { return lm_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true); @@ -5662,13 +5783,6 @@ static struct lm_ggml_tensor * lm_ggml_cpy_impl( struct lm_ggml_tensor * b) { LM_GGML_ASSERT(lm_ggml_nelements(a) == lm_ggml_nelements(b)); - bool is_node = false; - - if (a->grad || b->grad) { - // inplace is false and either one have a grad - is_node = true; - } - // make a view of the destination struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, b); if (strlen(b->name) > 0) { @@ -5677,8 +5791,7 @@ static struct lm_ggml_tensor * lm_ggml_cpy_impl( lm_ggml_format_name(result, "%s (copy)", a->name); } - result->op = LM_GGML_OP_CPY; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_CPY; result->src[0] = a; result->src[1] = b; @@ -5696,13 +5809,10 @@ struct lm_ggml_tensor * lm_ggml_cast( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, enum lm_ggml_type type) { - bool is_node = false; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, type, LM_GGML_MAX_DIMS, a->ne); lm_ggml_format_name(result, "%s (copy)", a->name); - result->op = LM_GGML_OP_CPY; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_CPY; result->src[0] = a; result->src[1] = result; @@ -5714,17 +5824,10 @@ struct lm_ggml_tensor * lm_ggml_cast( static struct lm_ggml_tensor * lm_ggml_cont_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); lm_ggml_format_name(result, "%s (cont)", a->name); - result->op = LM_GGML_OP_CONT; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_CONT; result->src[0] = a; return result; @@ -5770,13 +5873,10 @@ struct lm_ggml_tensor * lm_ggml_cont_4d( int64_t ne3) { LM_GGML_ASSERT(lm_ggml_nelements(a) == (ne0*ne1*ne2*ne3)); - bool is_node = false; - struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); lm_ggml_format_name(result, "%s (cont)", a->name); - result->op = LM_GGML_OP_CONT; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_CONT; result->src[0] = a; return result; @@ -5792,22 +5892,10 @@ struct lm_ggml_tensor * lm_ggml_reshape( // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. LM_GGML_ASSERT(lm_ggml_nelements(a) == lm_ggml_nelements(b)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - if (b->grad) { - // gradient propagation is not supported - //LM_GGML_ABORT("fatal error"); - } - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, LM_GGML_MAX_DIMS, b->ne, a, 0); lm_ggml_format_name(result, "%s (reshaped)", a->name); - result->op = LM_GGML_OP_RESHAPE; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -5820,18 +5908,11 @@ struct lm_ggml_tensor * lm_ggml_reshape_1d( LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[1] = { ne0 }; struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0); lm_ggml_format_name(result, "%s (reshaped)", a->name); - result->op = LM_GGML_OP_RESHAPE; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -5845,18 +5926,11 @@ struct lm_ggml_tensor * lm_ggml_reshape_2d( LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[2] = { ne0, ne1 }; struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0); lm_ggml_format_name(result, "%s (reshaped)", a->name); - result->op = LM_GGML_OP_RESHAPE; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -5871,18 +5945,11 @@ struct lm_ggml_tensor * lm_ggml_reshape_3d( LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1*ne2); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[3] = { ne0, ne1, ne2 }; struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0); lm_ggml_format_name(result, "%s (reshaped)", a->name); - result->op = LM_GGML_OP_RESHAPE; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -5898,18 +5965,11 @@ struct lm_ggml_tensor * lm_ggml_reshape_4d( LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1*ne2*ne3); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0); lm_ggml_format_name(result, "%s (reshaped)", a->name); - result->op = LM_GGML_OP_RESHAPE; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -5921,20 +5981,12 @@ static struct lm_ggml_tensor * lm_ggml_view_impl( int n_dims, const int64_t * ne, size_t offset) { - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset); lm_ggml_format_name(result, "%s (view)", a->name); lm_ggml_set_op_params(result, &offset, sizeof(offset)); - result->op = LM_GGML_OP_VIEW; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_VIEW; result->src[0] = a; return result; @@ -5947,7 +5999,6 @@ struct lm_ggml_tensor * lm_ggml_view_1d( struct lm_ggml_tensor * a, int64_t ne0, size_t offset) { - struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 1, &ne0, offset); return result; @@ -5962,7 +6013,6 @@ struct lm_ggml_tensor * lm_ggml_view_2d( int64_t ne1, size_t nb1, size_t offset) { - const int64_t ne[2] = { ne0, ne1 }; struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 2, ne, offset); @@ -5985,7 +6035,6 @@ struct lm_ggml_tensor * lm_ggml_view_3d( size_t nb1, size_t nb2, size_t offset) { - const int64_t ne[3] = { ne0, ne1, ne2 }; struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 3, ne, offset); @@ -6010,7 +6059,6 @@ struct lm_ggml_tensor * lm_ggml_view_4d( size_t nb2, size_t nb3, size_t offset) { - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 4, ne, offset); @@ -6043,12 +6091,6 @@ struct lm_ggml_tensor * lm_ggml_permute( LM_GGML_ASSERT(axis1 != axis3); LM_GGML_ASSERT(axis2 != axis3); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); lm_ggml_format_name(result, "%s (permuted)", a->name); @@ -6075,8 +6117,7 @@ struct lm_ggml_tensor * lm_ggml_permute( result->nb[2] = nb[2]; result->nb[3] = nb[3]; - result->op = LM_GGML_OP_PERMUTE; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_PERMUTE; result->src[0] = a; int32_t params[] = { axis0, axis1, axis2, axis3 }; @@ -6090,12 +6131,6 @@ struct lm_ggml_tensor * lm_ggml_permute( struct lm_ggml_tensor * lm_ggml_transpose( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); lm_ggml_format_name(result, "%s (transposed)", a->name); @@ -6105,8 +6140,7 @@ struct lm_ggml_tensor * lm_ggml_transpose( result->nb[0] = a->nb[1]; result->nb[1] = a->nb[0]; - result->op = LM_GGML_OP_TRANSPOSE; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_TRANSPOSE; result->src[0] = a; return result; @@ -6122,12 +6156,6 @@ struct lm_ggml_tensor * lm_ggml_get_rows( LM_GGML_ASSERT(b->ne[3] == 1); LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // TODO: implement non F32 return enum lm_ggml_type type = LM_GGML_TYPE_F32; if (a->type == LM_GGML_TYPE_I32) { @@ -6135,8 +6163,7 @@ struct lm_ggml_tensor * lm_ggml_get_rows( } struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); - result->op = LM_GGML_OP_GET_ROWS; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_GET_ROWS; result->src[0] = a; result->src[1] = b; @@ -6153,18 +6180,11 @@ struct lm_ggml_tensor * lm_ggml_get_rows_back( LM_GGML_ASSERT(lm_ggml_is_matrix(a) && lm_ggml_is_vector(b) && b->type == LM_GGML_TYPE_I32); LM_GGML_ASSERT(lm_ggml_is_matrix(c) && (a->ne[0] == c->ne[0])); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // TODO: implement non F32 return //struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, c->ne[0], c->ne[1]); - result->op = LM_GGML_OP_GET_ROWS_BACK; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_GET_ROWS_BACK; result->src[0] = a; result->src[1] = b; @@ -6177,17 +6197,11 @@ struct lm_ggml_tensor * lm_ggml_diag( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a) { LM_GGML_ASSERT(a->ne[1] == 1); - bool is_node = false; - - if (a->grad) { - is_node = true; - } const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, 4, ne); - result->op = LM_GGML_OP_DIAG; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_DIAG; result->src[0] = a; return result; @@ -6200,19 +6214,12 @@ static struct lm_ggml_tensor * lm_ggml_diag_mask_inf_impl( struct lm_ggml_tensor * a, int n_past, bool inplace) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); int32_t params[] = { n_past }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_DIAG_MASK_INF; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_DIAG_MASK_INF; result->src[0] = a; return result; @@ -6239,19 +6246,12 @@ static struct lm_ggml_tensor * lm_ggml_diag_mask_zero_impl( struct lm_ggml_tensor * a, int n_past, bool inplace) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); int32_t params[] = { n_past }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_DIAG_MASK_ZERO; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_DIAG_MASK_ZERO; result->src[0] = a; return result; @@ -6294,19 +6294,12 @@ static struct lm_ggml_tensor * lm_ggml_soft_max_impl( LM_GGML_ASSERT(mask); } - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); float params[] = { scale, max_bias }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_SOFT_MAX; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SOFT_MAX; result->src[0] = a; result->src[1] = mask; @@ -6341,16 +6334,9 @@ static struct lm_ggml_tensor * lm_ggml_soft_max_back_impl( struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, bool inplace) { - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; // TODO : implement backward pass - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_SOFT_MAX_BACK; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SOFT_MAX_BACK; result->src[0] = a; result->src[1] = b; @@ -6399,12 +6385,6 @@ static struct lm_ggml_tensor * lm_ggml_rope_impl( LM_GGML_ASSERT(c->ne[0] >= n_dims / 2); } - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; @@ -6416,8 +6396,7 @@ static struct lm_ggml_tensor * lm_ggml_rope_impl( memcpy(params + 10, &beta_slow, sizeof(float)); lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_ROPE; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_ROPE; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -6544,15 +6523,6 @@ struct lm_ggml_tensor * lm_ggml_rope_back( LM_GGML_ASSERT(lm_ggml_is_vector(b)); LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); LM_GGML_ASSERT(a->ne[2] == b->ne[0]); - LM_GGML_ASSERT(c == NULL && "freq factors not implemented yet"); - - LM_GGML_ASSERT((mode & 4) == 0 && "lm_ggml_rope_back() for ChatGLM not implemented yet"); - - bool is_node = false; - - if (a->grad) { - is_node = false; // TODO: implement backward - } struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); @@ -6565,10 +6535,10 @@ struct lm_ggml_tensor * lm_ggml_rope_back( memcpy(params + 10, &beta_slow, sizeof(float)); lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_ROPE_BACK; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_ROPE_BACK; result->src[0] = a; result->src[1] = b; + result->src[2] = c; return result; } @@ -6580,21 +6550,13 @@ struct lm_ggml_tensor * lm_ggml_clamp( struct lm_ggml_tensor * a, float min, float max) { - bool is_node = false; - - if (a->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - // TODO: when implement backward, fix this: struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); float params[] = { min, max }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_CLAMP; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_CLAMP; result->src[0] = a; return result; @@ -6656,13 +6618,6 @@ LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d( LM_GGML_ASSERT(p0 == 0); LM_GGML_ASSERT(d0 == 1); - bool is_node = false; - - if (a->grad || b->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { lm_ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), a->ne[1], b->ne[2], 1, @@ -6672,8 +6627,7 @@ LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d( int32_t params[] = { s0, p0, d0 }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_CONV_TRANSPOSE_1D; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_CONV_TRANSPOSE_1D; result->src[0] = a; result->src[1] = b; @@ -6681,17 +6635,17 @@ LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d( } // lm_ggml_conv_depthwise -struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { +struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { struct lm_ggml_tensor * new_a = lm_ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, new_a, lm_ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), @@ -6711,33 +6665,30 @@ struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d( // b: [N, IC, IH, IW] // result: [N, OH, OW, IC*KH*KW] struct lm_ggml_tensor * lm_ggml_im2col( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D, - enum lm_ggml_type dst_type) { - + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D, + enum lm_ggml_type dst_type) { if(is_2D) { LM_GGML_ASSERT(a->ne[2] == b->ne[2]); } else { LM_GGML_ASSERT(a->ne[1] == b->ne[1]); - } - bool is_node = false; - - if (a->grad || b->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; + LM_GGML_ASSERT(b->ne[3] == 1); } const int64_t OH = is_2D ? lm_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; const int64_t OW = lm_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + LM_GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a"); + LM_GGML_ASSERT((OW > 0) && "b too small compared to a"); + const int64_t ne[4] = { is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], OW, @@ -6749,8 +6700,30 @@ struct lm_ggml_tensor * lm_ggml_im2col( int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_IM2COL; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_IM2COL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct lm_ggml_tensor * lm_ggml_im2col_back( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int64_t * ne, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D) { + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + lm_ggml_set_op_params(result, params, sizeof(params)); + + result->op = LM_GGML_OP_IM2COL_BACK; result->src[0] = a; result->src[1] = b; @@ -6764,13 +6737,13 @@ struct lm_ggml_tensor * lm_ggml_conv_2d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { - struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, LM_GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW] + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW] struct lm_ggml_tensor * result = lm_ggml_mul_mat(ctx, @@ -6785,6 +6758,7 @@ struct lm_ggml_tensor * lm_ggml_conv_2d( } // lm_ggml_conv_2d_sk_p0 + struct lm_ggml_tensor * lm_ggml_conv_2d_sk_p0( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, @@ -6814,13 +6788,6 @@ struct lm_ggml_tensor * lm_ggml_conv_transpose_2d_p0( int stride) { LM_GGML_ASSERT(a->ne[3] == b->ne[2]); - bool is_node = false; - - if (a->grad || b->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { lm_ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/), lm_ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/), @@ -6831,8 +6798,7 @@ struct lm_ggml_tensor * lm_ggml_conv_transpose_2d_p0( lm_ggml_set_op_params_i32(result, 0, stride); - result->op = LM_GGML_OP_CONV_TRANSPOSE_2D; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_CONV_TRANSPOSE_2D; result->src[0] = a; result->src[1] = b; @@ -6854,14 +6820,6 @@ struct lm_ggml_tensor * lm_ggml_pool_1d( int k0, int s0, int p0) { - - bool is_node = false; - - if (a->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), a->ne[1], @@ -6873,8 +6831,7 @@ struct lm_ggml_tensor * lm_ggml_pool_1d( int32_t params[] = { op, k0, s0, p0 }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_POOL_1D; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_POOL_1D; result->src[0] = a; return result; @@ -6892,105 +6849,103 @@ struct lm_ggml_tensor * lm_ggml_pool_2d( int s1, float p0, float p1) { - - bool is_node = false; - - if (a->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - struct lm_ggml_tensor * result; - const int64_t ne[3] = { + const int64_t ne[4] = { lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), lm_ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), a->ne[2], + a->ne[3], }; - result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 3, ne); + result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_POOL_2D; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_POOL_2D; result->src[0] = a; + return result; } -// lm_ggml_upscale +struct lm_ggml_tensor * lm_ggml_pool_2d_back( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * af, + enum lm_ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1) { + struct lm_ggml_tensor * result; + result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, af->ne); -static struct lm_ggml_tensor * lm_ggml_upscale_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int ne0, - int ne1, - int ne2, - int ne3) { - bool is_node = false; + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + lm_ggml_set_op_params(result, params, sizeof(params)); - if (a->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } + result->op = LM_GGML_OP_POOL_2D_BACK; + result->src[0] = a; + result->src[1] = af; + return result; +} + +// lm_ggml_upscale + +static struct lm_ggml_tensor * lm_ggml_upscale_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3) { LM_GGML_ASSERT(a->ne[0] <= ne0); LM_GGML_ASSERT(a->ne[1] <= ne1); LM_GGML_ASSERT(a->ne[2] <= ne2); LM_GGML_ASSERT(a->ne[3] <= ne3); - struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, - ne0, - ne1, - ne2, - ne3 - ); - - result->op = LM_GGML_OP_UPSCALE; + struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_UPSCALE; result->src[0] = a; return result; } struct lm_ggml_tensor * lm_ggml_upscale( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int scale_factor) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int scale_factor) { return lm_ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]); } struct lm_ggml_tensor * lm_ggml_upscale_ext( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int ne0, - int ne1, - int ne2, - int ne3) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3) { return lm_ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3); } // lm_ggml_pad struct lm_ggml_tensor * lm_ggml_pad( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int p0, int p1, int p2, int p3) { - bool is_node = false; - - if (a->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int p0, + int p1, + int p2, + int p3) { struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, a->ne[0] + p0, a->ne[1] + p1, a->ne[2] + p2, a->ne[3] + p3); - result->op = LM_GGML_OP_PAD; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_PAD; result->src[0] = a; return result; @@ -6999,39 +6954,32 @@ struct lm_ggml_tensor * lm_ggml_pad( // lm_ggml_arange struct lm_ggml_tensor * lm_ggml_arange( - struct lm_ggml_context * ctx, - float start, - float stop, - float step) { - + struct lm_ggml_context * ctx, + float start, + float stop, + float step) { LM_GGML_ASSERT(stop > start); const int64_t steps = (int64_t) ceilf((stop - start) / step); struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, steps); - result->op = LM_GGML_OP_ARANGE; lm_ggml_set_op_params_f32(result, 0, start); lm_ggml_set_op_params_f32(result, 1, stop); lm_ggml_set_op_params_f32(result, 2, step); + result->op = LM_GGML_OP_ARANGE; + return result; } // lm_ggml_timestep_embedding struct lm_ggml_tensor * lm_ggml_timestep_embedding( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * timesteps, - int dim, - int max_period) { - bool is_node = false; - - if (timesteps->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * timesteps, + int dim, + int max_period) { int actual_dim = dim; if (dim % 2 != 0) { actual_dim = dim + 1; @@ -7039,11 +6987,10 @@ struct lm_ggml_tensor * lm_ggml_timestep_embedding( struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, actual_dim, timesteps->ne[0]); - result->op = LM_GGML_OP_TIMESTEP_EMBEDDING; lm_ggml_set_op_params_i32(result, 0, dim); lm_ggml_set_op_params_i32(result, 1, max_period); - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_TIMESTEP_EMBEDDING; result->src[0] = timesteps; return result; @@ -7052,17 +6999,14 @@ struct lm_ggml_tensor * lm_ggml_timestep_embedding( // lm_ggml_argsort struct lm_ggml_tensor * lm_ggml_argsort( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - enum lm_ggml_sort_order order) { - bool is_node = false; - + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + enum lm_ggml_sort_order order) { struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, LM_GGML_MAX_DIMS, a->ne); lm_ggml_set_op_params_i32(result, 0, (int32_t) order); - result->op = LM_GGML_OP_ARGSORT; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_ARGSORT; result->src[0] = a; return result; @@ -7115,10 +7059,6 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_ext( bool is_node = false; - if (q->grad || k->grad || v->grad) { - is_node = true; - } - // permute(0, 2, 1, 3) int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); @@ -7241,21 +7181,14 @@ struct lm_ggml_tensor * lm_ggml_ssm_conv( const int64_t n_s = sx->ne[2]; // TODO: maybe support other strides than 1? + // FIXME: this is always true? LM_GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); LM_GGML_ASSERT(sx->ne[1] == d_inner); LM_GGML_ASSERT(n_t >= 0); - bool is_node = false; - - if (sx->grad || c->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - struct lm_ggml_tensor * result = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_inner, n_t, n_s); - result->op = LM_GGML_OP_SSM_CONV; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_SSM_CONV; result->src[0] = sx; result->src[1] = c; @@ -7299,18 +7232,10 @@ struct lm_ggml_tensor * lm_ggml_ssm_scan( LM_GGML_ASSERT(B->ne[2] == n_seqs); } - bool is_node = false; - - if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - // concatenated y + ssm_states struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, lm_ggml_nelements(x) + lm_ggml_nelements(s)); result->op = LM_GGML_OP_SSM_SCAN; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = s; result->src[1] = x; result->src[2] = dt; @@ -7330,13 +7255,6 @@ struct lm_ggml_tensor * lm_ggml_win_part( LM_GGML_ASSERT(a->ne[3] == 1); LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32); - bool is_node = false; - - if (a->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - // padding const int px = (w - a->ne[1]%w)%w; const int py = (w - a->ne[2]%w)%w; @@ -7351,8 +7269,7 @@ struct lm_ggml_tensor * lm_ggml_win_part( int32_t params[] = { npx, npy, w }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_WIN_PART; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_WIN_PART; result->src[0] = a; return result; @@ -7368,21 +7285,13 @@ struct lm_ggml_tensor * lm_ggml_win_unpart( int w) { LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32); - bool is_node = false; - - if (a->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 3, ne); int32_t params[] = { w }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_WIN_UNPART; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_WIN_UNPART; result->src[0] = a; return result; @@ -7398,18 +7307,10 @@ struct lm_ggml_tensor * lm_ggml_get_rel_pos( LM_GGML_ASSERT(qh == kh); LM_GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); - bool is_node = false; - - if (a->grad) { - LM_GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F16, 3, ne); - result->op = LM_GGML_OP_GET_REL_POS; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_GET_REL_POS; result->src[0] = a; return result; @@ -7433,17 +7334,10 @@ static struct lm_ggml_tensor * lm_ggml_add_rel_pos_impl( LM_GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]); LM_GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]); - bool is_node = false; - - if (!inplace && (a->grad || pw->grad || ph->grad)) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); - result->op = LM_GGML_OP_ADD_REL_POS; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_ADD_REL_POS; result->src[0] = a; result->src[1] = pw; result->src[2] = ph; @@ -7467,27 +7361,65 @@ struct lm_ggml_tensor * lm_ggml_add_rel_pos_inplace( return lm_ggml_add_rel_pos_impl(ctx, a, pw, ph, true); } +// lm_ggml_rwkv_wkv + +struct lm_ggml_tensor * lm_ggml_rwkv_wkv( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * k, + struct lm_ggml_tensor * v, + struct lm_ggml_tensor * r, + struct lm_ggml_tensor * tf, + struct lm_ggml_tensor * td, + struct lm_ggml_tensor * state) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(k)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(v)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(r)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(tf)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(td)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(state)); + + const int64_t S = k->ne[0]; + const int64_t H = k->ne[2]; + const int64_t n_tokens = k->ne[3]; + const int64_t n_seqs = state->ne[1]; + { + LM_GGML_ASSERT(k->ne[1] == 1); + LM_GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens); + LM_GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens); + // TODO: RWKV v4 and v5 + LM_GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens); + LM_GGML_ASSERT(lm_ggml_nelements(state) == S * S * H * n_seqs); + } + + // concat output and new_state + const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + + result->op = LM_GGML_OP_RWKV_WKV; + result->src[0] = k; + result->src[1] = v; + result->src[2] = r; + result->src[3] = tf; + result->src[4] = td; + result->src[5] = state; + + return result; +} + // lm_ggml_unary static struct lm_ggml_tensor * lm_ggml_unary_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - enum lm_ggml_unary_op op, - bool inplace) { + struct lm_ggml_tensor * a, + enum lm_ggml_unary_op op, + bool inplace) { LM_GGML_ASSERT(lm_ggml_is_contiguous_1(a)); - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params_i32(result, 0, (int32_t) op); - result->op = LM_GGML_OP_UNARY; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_UNARY; result->src[0] = a; return result; @@ -7496,14 +7428,14 @@ static struct lm_ggml_tensor * lm_ggml_unary_impl( struct lm_ggml_tensor * lm_ggml_unary( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - enum lm_ggml_unary_op op) { + enum lm_ggml_unary_op op) { return lm_ggml_unary_impl(ctx, a, op, false); } struct lm_ggml_tensor * lm_ggml_unary_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - enum lm_ggml_unary_op op) { + enum lm_ggml_unary_op op) { return lm_ggml_unary_impl(ctx, a, op, true); } @@ -7512,20 +7444,13 @@ struct lm_ggml_tensor * lm_ggml_unary_inplace( static struct lm_ggml_tensor * lm_ggml_map_unary_impl_f32( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - const lm_ggml_unary_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - + const lm_ggml_unary_op_f32_t fun, + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = LM_GGML_OP_MAP_UNARY; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MAP_UNARY; result->src[0] = a; return result; @@ -7534,14 +7459,14 @@ static struct lm_ggml_tensor * lm_ggml_map_unary_impl_f32( struct lm_ggml_tensor * lm_ggml_map_unary_f32( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - const lm_ggml_unary_op_f32_t fun) { + const lm_ggml_unary_op_f32_t fun) { return lm_ggml_map_unary_impl_f32(ctx, a, fun, false); } struct lm_ggml_tensor * lm_ggml_map_unary_inplace_f32( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - const lm_ggml_unary_op_f32_t fun) { + const lm_ggml_unary_op_f32_t fun) { return lm_ggml_map_unary_impl_f32(ctx, a, fun, true); } @@ -7551,22 +7476,15 @@ static struct lm_ggml_tensor * lm_ggml_map_binary_impl_f32( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - const lm_ggml_binary_op_f32_t fun, - bool inplace) { + const lm_ggml_binary_op_f32_t fun, + bool inplace) { LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = LM_GGML_OP_MAP_BINARY; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MAP_BINARY; result->src[0] = a; result->src[1] = b; @@ -7577,7 +7495,7 @@ struct lm_ggml_tensor * lm_ggml_map_binary_f32( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - const lm_ggml_binary_op_f32_t fun) { + const lm_ggml_binary_op_f32_t fun) { return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, false); } @@ -7585,7 +7503,7 @@ struct lm_ggml_tensor * lm_ggml_map_binary_inplace_f32( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - const lm_ggml_binary_op_f32_t fun) { + const lm_ggml_binary_op_f32_t fun) { return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, true); } @@ -7595,19 +7513,12 @@ static struct lm_ggml_tensor * lm_ggml_map_custom1_impl_f32( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, const lm_ggml_custom1_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = LM_GGML_OP_MAP_CUSTOM1_F32; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MAP_CUSTOM1_F32; result->src[0] = a; return result; @@ -7634,19 +7545,12 @@ static struct lm_ggml_tensor * lm_ggml_map_custom2_impl_f32( struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, const lm_ggml_custom2_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = LM_GGML_OP_MAP_CUSTOM2_F32; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MAP_CUSTOM2_F32; result->src[0] = a; result->src[1] = b; @@ -7677,19 +7581,12 @@ static struct lm_ggml_tensor * lm_ggml_map_custom3_impl_f32( struct lm_ggml_tensor * b, struct lm_ggml_tensor * c, const lm_ggml_custom3_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad || b->grad || c->grad)) { - is_node = true; - } - + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = LM_GGML_OP_MAP_CUSTOM3_F32; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MAP_CUSTOM3_F32; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -7717,26 +7614,20 @@ struct lm_ggml_tensor * lm_ggml_map_custom3_inplace_f32( // lm_ggml_map_custom1 struct lm_ggml_map_custom1_op_params { - lm_ggml_custom1_op_t fun; - int n_tasks; - void * userdata; + lm_ggml_custom1_op_t fun; + int n_tasks; + void * userdata; }; static struct lm_ggml_tensor * lm_ggml_map_custom1_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_custom1_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_custom1_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0); - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); struct lm_ggml_map_custom1_op_params params = { @@ -7746,55 +7637,48 @@ static struct lm_ggml_tensor * lm_ggml_map_custom1_impl( }; lm_ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - result->op = LM_GGML_OP_MAP_CUSTOM1; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MAP_CUSTOM1; result->src[0] = a; return result; } struct lm_ggml_tensor * lm_ggml_map_custom1( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_custom1_op_t fun, - int n_tasks, - void * userdata) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { return lm_ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false); } struct lm_ggml_tensor * lm_ggml_map_custom1_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_custom1_op_t fun, - int n_tasks, - void * userdata) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { return lm_ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true); } // lm_ggml_map_custom2 struct lm_ggml_map_custom2_op_params { - lm_ggml_custom2_op_t fun; - int n_tasks; - void * userdata; + lm_ggml_custom2_op_t fun; + int n_tasks; + void * userdata; }; static struct lm_ggml_tensor * lm_ggml_map_custom2_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_custom2_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_custom2_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); struct lm_ggml_map_custom2_op_params params = { @@ -7804,8 +7688,7 @@ static struct lm_ggml_tensor * lm_ggml_map_custom2_impl( }; lm_ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - result->op = LM_GGML_OP_MAP_CUSTOM2; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MAP_CUSTOM2; result->src[0] = a; result->src[1] = b; @@ -7813,22 +7696,22 @@ static struct lm_ggml_tensor * lm_ggml_map_custom2_impl( } struct lm_ggml_tensor * lm_ggml_map_custom2( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_custom2_op_t fun, - int n_tasks, - void * userdata) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { return lm_ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false); } struct lm_ggml_tensor * lm_ggml_map_custom2_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_custom2_op_t fun, - int n_tasks, - void * userdata) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { return lm_ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true); } @@ -7841,22 +7724,16 @@ struct lm_ggml_map_custom3_op_params { }; static struct lm_ggml_tensor * lm_ggml_map_custom3_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - const lm_ggml_custom3_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + const lm_ggml_custom3_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0); - bool is_node = false; - - if (!inplace && (a->grad || b->grad || c->grad)) { - is_node = true; - } - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); struct lm_ggml_map_custom3_op_params params = { @@ -7866,8 +7743,7 @@ static struct lm_ggml_tensor * lm_ggml_map_custom3_impl( }; lm_ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - result->op = LM_GGML_OP_MAP_CUSTOM3; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_MAP_CUSTOM3; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -7876,44 +7752,38 @@ static struct lm_ggml_tensor * lm_ggml_map_custom3_impl( } struct lm_ggml_tensor * lm_ggml_map_custom3( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - const lm_ggml_custom3_op_t fun, - int n_tasks, - void * userdata) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + const lm_ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { return lm_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false); } struct lm_ggml_tensor * lm_ggml_map_custom3_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - const lm_ggml_custom3_op_t fun, - int n_tasks, - void * userdata) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + const lm_ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { return lm_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); } // lm_ggml_cross_entropy_loss struct lm_ggml_tensor * lm_ggml_cross_entropy_loss( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, a->type, 1); - result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS; result->src[0] = a; result->src[1] = b; @@ -7923,36 +7793,63 @@ struct lm_ggml_tensor * lm_ggml_cross_entropy_loss( // lm_ggml_cross_entropy_loss_back struct lm_ggml_tensor * lm_ggml_cross_entropy_loss_back( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c) { + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c) { LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); LM_GGML_ASSERT(lm_ggml_is_scalar(c)); struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK; - result->grad = NULL; + result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +// opt_step_adamw + +struct lm_ggml_tensor * lm_ggml_opt_step_adamw( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * grad, + float alpha, + float beta1, + float beta2, + float eps, + float wd) { + LM_GGML_ASSERT(a->flags & LM_GGML_TENSOR_FLAG_PARAM); + LM_GGML_ASSERT(lm_ggml_are_same_shape(a, grad)); + LM_GGML_ASSERT(alpha > 0.0f); + LM_GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); + LM_GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); + LM_GGML_ASSERT(eps >= 0.0f); + LM_GGML_ASSERT(wd >= 0.0f && wd <= 1.0f); + + struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); + + const int64_t iter = 1; + memcpy(&result->op_params[0], &iter, sizeof(int64_t)); + lm_ggml_set_op_params_f32(result, 2, alpha); + lm_ggml_set_op_params_f32(result, 3, beta1); + lm_ggml_set_op_params_f32(result, 4, beta2); + lm_ggml_set_op_params_f32(result, 5, eps); + lm_ggml_set_op_params_f32(result, 6, wd); + + result->op = LM_GGML_OP_OPT_STEP_ADAMW; result->src[0] = a; - result->src[1] = b; - result->src[2] = c; + result->src[1] = grad; + result->src[2] = lm_ggml_dup_tensor(ctx, grad); + result->src[3] = lm_ggml_dup_tensor(ctx, grad); return result; } //////////////////////////////////////////////////////////////////////////////// -void lm_ggml_set_param( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * tensor) { - tensor->flags |= LM_GGML_TENSOR_FLAG_PARAM; - - LM_GGML_ASSERT(tensor->grad == NULL); - tensor->grad = lm_ggml_dup_tensor(ctx, tensor); - lm_ggml_format_name(tensor->grad, "%s (grad)", tensor->name); -} - // lm_ggml_compute_forward_dup static void lm_ggml_compute_forward_dup_same_cont( @@ -7965,8 +7862,7 @@ static void lm_ggml_compute_forward_dup_same_cont( LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); LM_GGML_ASSERT(src0->type == dst->type); - const size_t nb00 = src0->nb[0]; - const size_t nb0 = dst->nb[0]; + const size_t nb0 = lm_ggml_type_size(src0->type); const int ith = params->ith; // thread index const int nth = params->nth; // number of threads @@ -7980,8 +7876,8 @@ static void lm_ggml_compute_forward_dup_same_cont( if (ie0 < ie1) { memcpy( ((char *) dst->data + ie0*nb0), - ((char *) src0->data + ie0*nb00), - (ie1 - ie0) * lm_ggml_type_size(src0->type)); + ((char *) src0->data + ie0*nb0), + (ie1 - ie0) * nb0); } } @@ -7998,11 +7894,6 @@ static void lm_ggml_compute_forward_dup_f16( const int ith = params->ith; // thread index const int nth = params->nth; // number of threads - if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst) && src0->type == dst->type) { - lm_ggml_compute_forward_dup_same_cont(params, dst); - return; - } - // parallelize by rows const int nr = ne01; // number of rows per thread @@ -8267,11 +8158,6 @@ static void lm_ggml_compute_forward_dup_bf16( const int ith = params->ith; // thread index const int nth = params->nth; // number of threads - if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst) && src0->type == dst->type) { - lm_ggml_compute_forward_dup_same_cont(params, dst); - return; - } - // parallelize by rows const int nr = ne01; // number of rows per thread @@ -8623,11 +8509,6 @@ static void lm_ggml_compute_forward_dup_f32( const int ith = params->ith; // thread index const int nth = params->nth; // number of threads - if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst) && src0->type == dst->type) { - lm_ggml_compute_forward_dup_same_cont(params, dst); - return; - } - // parallelize by rows const int nr = ne01; // number of rows per thread @@ -8937,13 +8818,13 @@ static void lm_ggml_compute_forward_dup_bytes( LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); LM_GGML_ASSERT(src0->type == dst->type); + LM_GGML_TENSOR_UNARY_OP_LOCALS; + if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst)) { lm_ggml_compute_forward_dup_same_cont(params, dst); return; } - LM_GGML_TENSOR_UNARY_OP_LOCALS; - const size_t type_size = lm_ggml_type_size(src0->type); const int ith = params->ith; // thread index const int nth = params->nth; // number of threads @@ -9564,6 +9445,8 @@ static void lm_ggml_compute_forward_add( case LM_GGML_TYPE_Q4_K: case LM_GGML_TYPE_Q5_K: case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: case LM_GGML_TYPE_IQ2_XXS: case LM_GGML_TYPE_IQ2_XS: case LM_GGML_TYPE_IQ3_XXS: @@ -9942,6 +9825,8 @@ static void lm_ggml_compute_forward_add1( case LM_GGML_TYPE_Q4_K: case LM_GGML_TYPE_Q5_K: case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: case LM_GGML_TYPE_IQ2_XXS: case LM_GGML_TYPE_IQ2_XS: case LM_GGML_TYPE_IQ3_XXS: @@ -9993,7 +9878,7 @@ static void lm_ggml_compute_forward_acc_f32( ((char *) src0->data), lm_ggml_nbytes(dst)); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); } const int ith = params->ith; @@ -10070,6 +9955,8 @@ static void lm_ggml_compute_forward_acc( case LM_GGML_TYPE_Q4_K: case LM_GGML_TYPE_Q5_K: case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: case LM_GGML_TYPE_IQ2_XXS: case LM_GGML_TYPE_IQ2_XS: case LM_GGML_TYPE_IQ3_XXS: @@ -10098,11 +9985,10 @@ static void lm_ggml_compute_forward_sub_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - if (params->ith != 0) { - return; - } + assert(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); - assert(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); + const int ith = params->ith; + const int nth = params->nth; const int nr = lm_ggml_nrows(src0); @@ -10111,40 +9997,55 @@ static void lm_ggml_compute_forward_sub_f32( LM_GGML_ASSERT( nb0 == sizeof(float)); LM_GGML_ASSERT(nb00 == sizeof(float)); + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + if (nb10 == sizeof(float)) { - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + for (int64_t r = 0; r < nr0; ++r) { #ifdef LM_GGML_USE_ACCELERATE - vDSP_vsub( - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); + vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); #else - lm_ggml_vec_sub_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + lm_ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); #endif - // } - // } + } } } else { // src1 is not contiguous - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; } @@ -10490,9 +10391,9 @@ static void lm_ggml_compute_forward_log( } } -// lm_ggml_compute_forward_sum +// lm_ggml_compute_forward_sin -static void lm_ggml_compute_forward_sum_f32( +static void lm_ggml_compute_forward_sin_f32( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { @@ -10502,8 +10403,95 @@ static void lm_ggml_compute_forward_sum_f32( return; } - assert(lm_ggml_is_scalar(dst)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + lm_ggml_vec_sin_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_sin( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_sin_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_cos + +static void lm_ggml_compute_forward_cos_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_cos_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_cos( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_cos_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_sum + +static void lm_ggml_compute_forward_sum_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } assert(lm_ggml_is_scalar(dst)); assert(src0->nb[0] == sizeof(float)); @@ -10777,6 +10765,86 @@ static void lm_ggml_compute_forward_argmax( } } +// lm_ggml_compute_forward_count_equal + +static void lm_ggml_compute_forward_count_equal_i32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS; + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_I32); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_I32); + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1)); + LM_GGML_ASSERT(lm_ggml_is_scalar(dst)); + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_I64); + + const int64_t nr = lm_ggml_nrows(src0); + + const int ith = params->ith; + const int nth = params->nth; + + int64_t * sums = (int64_t *) params->wdata; + int64_t sum_thread = 0; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02*ne01); + const int64_t i02 = (ir - i03*ne03) / ne01; + const int64_t i01 = ir - i03*ne03 - i02*ne02; + + const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01; + const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11; + + for (int64_t i00 = 0; i00 < ne00; ++i00) { + const int32_t val0 = *((const int32_t *) (data0 + i00*nb00)); + const int32_t val1 = *((const int32_t *) (data1 + i00*nb10)); + + sum_thread += val0 == val1; + } + } + if (ith != 0) { + sums[ith] = sum_thread; + } + lm_ggml_barrier(params->threadpool); + + if (ith != 0) { + return; + } + + for (int ith_other = 1; ith_other < nth; ++ith_other) { + sum_thread += sums[ith_other]; + } + *((int64_t *) dst->data) = sum_thread; +} + +static void lm_ggml_compute_forward_count_equal( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_I32: + { + lm_ggml_compute_forward_count_equal_i32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + // lm_ggml_compute_forward_repeat static void lm_ggml_compute_forward_repeat_f32( @@ -11744,7 +11812,49 @@ static void lm_ggml_compute_forward_hardsigmoid_f32( } } -static void lm_ggml_compute_forward_hardsigmoid( +static void lm_ggml_compute_forward_hardsigmoid( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_hardsigmoid_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +static void lm_ggml_compute_forward_exp_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_exp_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_exp( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { @@ -11753,7 +11863,7 @@ static void lm_ggml_compute_forward_hardsigmoid( switch (src0->type) { case LM_GGML_TYPE_F32: { - lm_ggml_compute_forward_hardsigmoid_f32(params, dst); + lm_ggml_compute_forward_exp_f32(params, dst); } break; default: { @@ -12363,10 +12473,10 @@ UseGgmlGemm1:; if (ith == 0) { // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - atomic_store(¶ms->shared->current_chunk, nth); + atomic_store_explicit(¶ms->threadpool->current_chunk, nth, memory_order_relaxed); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); #if LM_GGML_USE_LLAMAFILE if (src1->type != vec_dot_type) { @@ -12474,7 +12584,7 @@ UseGgmlGemm2:; break; } - current_chunk = atomic_fetch_add(¶ms->shared->current_chunk, 1); + current_chunk = atomic_fetch_add_explicit(¶ms->threadpool->current_chunk, 1, memory_order_relaxed); } } @@ -12569,7 +12679,7 @@ static void lm_ggml_compute_forward_mul_mat_id( } } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); // compute each matrix multiplication in sequence for (int cur_a = 0; cur_a < n_as; ++cur_a) { @@ -12698,6 +12808,10 @@ static void lm_ggml_compute_forward_out_prod_f32( LM_GGML_TENSOR_BINARY_OP_LOCALS + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + const int ith = params->ith; const int nth = params->nth; @@ -12723,7 +12837,7 @@ static void lm_ggml_compute_forward_out_prod_f32( if (ith == 0) { lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); // dst[:,:,:,:] = 0 // for i2,i3: @@ -12841,7 +12955,7 @@ static void lm_ggml_compute_forward_out_prod_q_f32( if (ith == 0) { lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); // parallelize by last three dimensions @@ -12907,6 +13021,8 @@ static void lm_ggml_compute_forward_out_prod( case LM_GGML_TYPE_Q4_K: case LM_GGML_TYPE_Q5_K: case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: case LM_GGML_TYPE_IQ2_XXS: case LM_GGML_TYPE_IQ2_XS: case LM_GGML_TYPE_IQ3_XXS: @@ -13027,7 +13143,7 @@ static void lm_ggml_compute_forward_set_f32( ((char *) src0->data), lm_ggml_nbytes(dst)); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); } const int ith = params->ith; @@ -13095,6 +13211,8 @@ static void lm_ggml_compute_forward_set( case LM_GGML_TYPE_Q4_K: case LM_GGML_TYPE_Q5_K: case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: case LM_GGML_TYPE_IQ2_XXS: case LM_GGML_TYPE_IQ2_XS: case LM_GGML_TYPE_IQ3_XXS: @@ -13208,7 +13326,7 @@ static void lm_ggml_compute_forward_get_rows_q( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); dequantize_row_q( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), @@ -13249,7 +13367,7 @@ static void lm_ggml_compute_forward_get_rows_f16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); lm_ggml_fp16_to_fp32_row( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), @@ -13290,7 +13408,7 @@ static void lm_ggml_compute_forward_get_rows_bf16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); lm_ggml_bf16_to_fp32_row( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), @@ -13331,7 +13449,7 @@ static void lm_ggml_compute_forward_get_rows_f32( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); lm_ggml_vec_cpy_f32(nc, (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), @@ -13357,6 +13475,8 @@ static void lm_ggml_compute_forward_get_rows( case LM_GGML_TYPE_Q4_K: case LM_GGML_TYPE_Q5_K: case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: case LM_GGML_TYPE_IQ2_XXS: case LM_GGML_TYPE_IQ2_XS: case LM_GGML_TYPE_IQ3_XXS: @@ -13606,7 +13726,7 @@ static void lm_ggml_compute_forward_diag_mask_f32( ((char *) src0->data), lm_ggml_nbytes(dst)); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); } // TODO: handle transposed/permuted matrices @@ -13946,6 +14066,8 @@ static void lm_ggml_compute_forward_clamp( case LM_GGML_TYPE_Q4_K: case LM_GGML_TYPE_Q5_K: case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: case LM_GGML_TYPE_IQ2_XXS: case LM_GGML_TYPE_IQ2_XS: case LM_GGML_TYPE_IQ3_XXS: @@ -14019,7 +14141,7 @@ static void lm_ggml_rope_cache_init( } } -LM_GGML_CALL void lm_ggml_rope_yarn_corr_dims( +void lm_ggml_rope_yarn_corr_dims( int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] ) { // start and end correction dims @@ -14382,7 +14504,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f16_f32( // need to zero dst since we are accumulating into it memset(dst->data, 0, lm_ggml_nbytes(dst)); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; @@ -14470,7 +14592,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f32( // need to zero dst since we are accumulating into it memset(dst->data, 0, lm_ggml_nbytes(dst)); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; @@ -14525,6 +14647,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d( } } +// lm_ggml_compute_forward_im2col_f32 // src0: kernel [OC, IC, KH, KW] // src1: image [N, IC, IH, IW] // dst: result [N, OH, OW, IC*KH*KW] @@ -14535,7 +14658,6 @@ static void lm_ggml_compute_forward_im2col_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); @@ -14566,7 +14688,6 @@ static void lm_ggml_compute_forward_im2col_f32( int ofs0 = is_2D ? nb13 : nb12; int ofs1 = is_2D ? nb12 : nb11; - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); LM_GGML_ASSERT(nb10 == sizeof(float)); // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] @@ -14602,6 +14723,7 @@ static void lm_ggml_compute_forward_im2col_f32( } +// lm_ggml_compute_forward_im2col_f16 // src0: kernel [OC, IC, KH, KW] // src1: image [N, IC, IH, IW] // dst: result [N, OH, OW, IC*KH*KW] @@ -14697,6 +14819,99 @@ static void lm_ggml_compute_forward_im2col( } } +// lm_ggml_compute_forward_im2col_back_f32 + +static void lm_ggml_compute_forward_im2col_back_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); + + LM_GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne3 : ne2; + const int64_t IC = is_2D ? ne2 : ne1; + const int64_t IH = is_2D ? ne1 : 1; + const int64_t IW = ne0; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne12 : 1; + const int64_t OW = ne11; + + int ofs0 = is_2D ? nb3 : nb2; + int ofs1 = is_2D ? nb2 : nb1; + + LM_GGML_ASSERT(nb0 == sizeof(float)); + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + for (int64_t iih = 0; iih < IH; iih++) { + for (int64_t iiw = 0; iiw < IW; iiw++) { + + // micro kernel + float grad = 0.0f; + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + // For s0 > 1 some values were skipped over in the forward pass. + // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well. + const int64_t tmpw = (iiw + p0 - ikw*d0); + if (tmpw % s0 != 0) { + continue; + } + const int64_t iow = tmpw / s0; + + // Equivalent logic as above except for s1. + int64_t ioh; + if (is_2D) { + const int64_t tmph = iih + p1 - ikh*d1; + + if (tmph % s1 != 0) { + continue; + } + + ioh = tmph / s1; + } else { + ioh = 0; + } + + if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) { + continue; + } + + const float * const src_data = (const float *) src1->data + + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + grad += src_data[iic*(KH*KW) + ikh*KW + ikw]; + } + } + float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] + dst_data[iih*IW + iiw] = grad; + } + } + } + } + } +} // lm_ggml_compute_forward_conv_transpose_2d @@ -14757,7 +14972,7 @@ static void lm_ggml_compute_forward_conv_transpose_2d( memset(dst->data, 0, lm_ggml_nbytes(dst)); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); const int32_t stride = lm_ggml_get_op_params_i32(dst, 0); @@ -14939,6 +15154,128 @@ static void lm_ggml_compute_forward_pool_2d( } } +// lm_ggml_compute_forward_pool_2d_back + +static void lm_ggml_compute_forward_pool_2d_back( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src = dst->src[0]; + const struct lm_ggml_tensor * dstf = dst->src[1]; // forward tensor of dst + + assert(dst->type == LM_GGML_TYPE_F32 || dst->type == LM_GGML_TYPE_F16); + + if (params->ith != 0) { + return; + } + + const int32_t * opts = (const int32_t *)dst->op_params; + enum lm_ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + + char * cdata = (char *) dst->data; + const char * cdataf = (const char *) dstf->data; + const char * const data_end = cdata + lm_ggml_nbytes(dst); + + LM_GGML_ASSERT(params->ith == 0); + memset(cdata, 0, lm_ggml_nbytes(dst)); + + const int64_t px = src->ne[0]; + const int64_t py = src->ne[1]; + const int64_t pa = px * py; + + const float * splane = (const float *) src->data; + + const int ka = k0 * k1; + const int offset0 = -p0; + const int offset1 = -p1; + + while (cdata < data_end) { + for (int oy = 0; oy < py; ++oy) { + const float * const srow = splane + oy * px; + for (int ox = 0; ox < px; ++ox) { + const float grad0 = srow[ox]; + + const int ix = offset0 + ox * s0; + const int iy = offset1 + oy * s1; + + if (op == LM_GGML_OP_POOL_MAX) { + float maxval = -FLT_MAX; + int kxmax = -1; + int kymax = -1; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= dst->ne[1]) { + continue; + } + const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= dst->ne[0]) { + continue; + } + + const float val = dst->type == LM_GGML_TYPE_F32 ? + ((const float *) drowf)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drowf)[j]); + if (val <= maxval) { + continue; + } + + maxval = val; + kxmax = kx; + kymax = ky; + } + } + + if (kxmax == -1 || kymax == -1) { + continue; + } + + void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax)); + const int j = ix + kxmax; + if (dst->type == LM_GGML_TYPE_F32) { + ((float *) drow)[j] += grad0; + } else { + ((lm_ggml_fp16_t *) drow)[j] = LM_GGML_FP32_TO_FP16(grad0 + LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drow)[j])); + } + } else if (op == LM_GGML_OP_POOL_AVG) { + const float grad = grad0 / ka; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= dst->ne[1]) { + continue; + } + void * drow = (void *)(cdata + dst->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= dst->ne[0]) { + continue; + } + + if (dst->type == LM_GGML_TYPE_F32) { + ((float *) drow)[j] += grad; + } else { + ((lm_ggml_fp16_t *) drow)[j] += LM_GGML_FP32_TO_FP16(grad); + } + } + } + } else { + LM_GGML_ASSERT(false); + } + } + } + + cdata += dst->nb[2]; + cdataf += dst->nb[2]; + splane += pa; + } +} + // lm_ggml_compute_forward_upscale static void lm_ggml_compute_forward_upscale_f32( @@ -15295,6 +15632,9 @@ static void lm_ggml_compute_forward_flash_attn_ext_f16( lm_ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; lm_ggml_to_float_t const v_to_float = type_traits[v->type].to_float; + LM_GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type"); + LM_GGML_ASSERT(v_to_float && "fattn: unsupported V-type"); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices @@ -15503,7 +15843,7 @@ static void lm_ggml_compute_forward_flash_attn_back_f32( if (ith == 0) { memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); const int64_t elem_q = lm_ggml_nelements(q); const int64_t elem_k = lm_ggml_nelements(k); @@ -15898,8 +16238,8 @@ static void lm_ggml_compute_forward_ssm_scan_f32( const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} // use the output as the source for the next token-wise iterations if (i2 > 0) { s0 = s; } @@ -16125,6 +16465,10 @@ static void lm_ggml_compute_forward_unary( { lm_ggml_compute_forward_hardsigmoid(params, dst); } break; + case LM_GGML_UNARY_OP_EXP: + { + lm_ggml_compute_forward_exp(params, dst); + } break; default: { LM_GGML_ABORT("fatal error"); @@ -16194,7 +16538,7 @@ static void lm_ggml_compute_forward_add_rel_pos_f32( if (params->ith == 0) { memcpy((char *) dst->data, (char *) src0->data, lm_ggml_nbytes(dst)); } - lm_ggml_barrier(params->shared); + lm_ggml_barrier(params->threadpool); } // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 @@ -16260,6 +16604,96 @@ static void lm_ggml_compute_forward_add_rel_pos( } } +// lm_ggml_compute_forward_rwkv_wkv + +static void lm_ggml_compute_forward_rwkv_wkv_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + const size_t T = dst->src[1]->ne[3]; + const size_t C = dst->ne[0]; + const size_t H = dst->src[1]->ne[2]; + const size_t n_seqs = dst->src[5]->ne[1]; + + float * dst_data = (float *) dst->data; + float * state = ((float *) dst->data) + C * T; + + if (params->ith != 0) { + return; + } + + memset(dst_data, 0, T * C * sizeof(float)); + + float * k = (float *) dst->src[0]->data; + float * v = (float *) dst->src[1]->data; + float * r = (float *) dst->src[2]->data; + float * time_faaaa = (float *) dst->src[3]->data; + float * time_decay = (float *) dst->src[4]->data; + + size_t t_stride = H * (C / H); + + size_t h_stride = C / H; + size_t h_stride_2d = (C / H) * (C / H); + + // basically fused operations: + // dst = r @ (time_faaaa * (k @ v) + state), + // state = time_decay * state + (k @ v), + // recursive through each token + for (size_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + size_t state_offset = (C / H) * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; + + for (size_t h = 0; h < H; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (size_t i = 0; i < C / H; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_i_offset = h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float r_val = r[t_h_i_offset]; + float time_faaaa_val = time_faaaa[h_i_offset]; + // RWKV v6: different time_decay for each token. + float time_decay_val = time_decay[t_h_i_offset]; + + for (size_t j = 0; j < C / H; j ++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + float temp_val = kv_val * time_faaaa_val + prev_state_val; + dst_data[t_h_j_offset] += temp_val * r_val; + state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + } + } + } + } +} + +static void lm_ggml_compute_forward_rwkv_wkv( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_rwkv_wkv_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + // lm_ggml_compute_forward_map_unary static void lm_ggml_compute_forward_map_unary_f32( @@ -16460,74 +16894,67 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); - LM_GGML_ASSERT(lm_ggml_is_scalar(dst)); + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src0->nb[0] == lm_ggml_type_size(src0->type)); + LM_GGML_ASSERT(src1->nb[0] == lm_ggml_type_size(src1->type)); LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1)); + LM_GGML_ASSERT(lm_ggml_is_scalar(dst)); + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32); + + // TODO: handle transposed/permuted matrices + const int64_t nc = src0->ne[0]; + const int64_t nr = lm_ggml_nrows(src0); const int ith = params->ith; const int nth = params->nth; - float * sums = (float *) params->wdata; - - // TODO: handle transposed/permuted matrices - const int nc = src0->ne[0]; - const int nr = lm_ggml_nrows(src0); + float * sums = (float *) params->wdata; + float * st = ((float *) params->wdata) + nth + ith*nc; + float sum_thread = 0.0f; LM_GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); - if (ith == 0) { - memset(sums, 0, sizeof(float) * (nth + nth * nc)); - } - lm_ggml_barrier(params->shared); - - const double eps = 1e-9; - // rows per thread - const int dr = (nr + nth - 1)/nth; + const int64_t dr = (nr + nth - 1)/nth; // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { - float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); - float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); - float * st = ((float *) params->wdata) + nth + ith*nc; + for (int64_t i1 = ir0; i1 < ir1; ++i1) { + const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]); + const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]); #ifndef NDEBUG - for (int i = 0; i < nc; ++i) { + for (int64_t i = 0; i < nc; ++i) { //printf("p[%d] = %f\n", i, p[i]); assert(!isnan(s0[i])); assert(!isnan(s1[i])); } #endif - // soft_max float max = -INFINITY; lm_ggml_vec_max_f32(nc, &max, s0); - lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, st, s0, max); - assert(sum > 0.0); - sum = (1.0 - eps) / sum; + const lm_ggml_float sum_softmax = lm_ggml_vec_log_soft_max_f32(nc, st, s0, max); + assert(sum_softmax >= 0.0); - // avoid log(0) by rescaling from [0..1] to [eps..1] - lm_ggml_vec_scale_f32(nc, st, sum); - lm_ggml_vec_add1_f32(nc, st, st, eps); - lm_ggml_vec_log_f32(nc, st, st); + lm_ggml_vec_add1_f32(nc, st, st, -sum_softmax); lm_ggml_vec_mul_f32(nc, st, st, s1); - float st_sum = 0; - lm_ggml_vec_sum_f32(nc, &st_sum, st); - sums[ith] += st_sum; + float sum_st = 0.0f; + lm_ggml_vec_sum_f32(nc, &sum_st, st); + sum_thread += sum_st; #ifndef NDEBUG - for (int i = 0; i < nc; ++i) { + for (int64_t i = 0; i < nc; ++i) { assert(!isnan(st[i])); assert(!isinf(st[i])); } #endif } - lm_ggml_barrier(params->shared); + sums[ith] = sum_thread; + lm_ggml_barrier(params->threadpool); if (ith == 0) { float * dp = (float *) dst->data; @@ -16573,8 +17000,6 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ith = params->ith; const int64_t nth = params->nth; - const double eps = 1e-9; - // TODO: handle transposed/permuted matrices const int64_t nc = src0->ne[0]; const int64_t nr = lm_ggml_nrows(src0); @@ -16586,44 +17011,131 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ir0 = dr*ith; const int64_t ir1 = MIN(ir0 + dr, nr); - float * d = (float *) opt0->data; + const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr; for (int64_t i1 = ir0; i1 < ir1; i1++) { float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(s0[i])); - assert(!isnan(s1[i])); - } -#endif +#ifndef NDEBUG + for (int64_t i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + + // soft_max + float max = -INFINITY; + lm_ggml_vec_max_f32(nc, &max, s0); + lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, ds0, s0, max); + assert(sum > 0.0); + lm_ggml_vec_scale_f32(nc, ds0, 1.0/sum); + + // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr + lm_ggml_vec_sub_f32(nc, ds0, ds0, s1); + lm_ggml_vec_scale_f32(nc, ds0, d_by_nr); + +#ifndef NDEBUG + for (int64_t i = 0; i < nc; ++i) { + assert(!isnan(ds0[i])); + assert(!isinf(ds0[i])); + } +#endif + } +} + +static void lm_ggml_compute_forward_cross_entropy_loss_back( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_cross_entropy_loss_back_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +static void lm_ggml_compute_forward_opt_step_adamw_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src0_grad = dst->src[1]; + const struct lm_ggml_tensor * src0_grad_m = dst->src[2]; + const struct lm_ggml_tensor * src0_grad_v = dst->src[3]; + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src0_grad)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + LM_GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + /* const float gnorm = 1.0f; */ + int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t)); + const float alpha = lm_ggml_get_op_params_f32(dst, 2); + const float beta1 = lm_ggml_get_op_params_f32(dst, 3); + const float beta2 = lm_ggml_get_op_params_f32(dst, 4); + const float eps = lm_ggml_get_op_params_f32(dst, 5); + const float wd = lm_ggml_get_op_params_f32(dst, 6); + + const float beta1h = alpha/(1.0f - powf(beta1, iter)); + const float beta2h = 1.0f/(1.0f - powf(beta2, iter)); + + for (int ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const size_t offset = i03*nb03 + i02*nb02 + i01*nb01; - // soft_max - float max = -INFINITY; - lm_ggml_vec_max_f32(nc, &max, s0); - lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, ds0, s0, max); - assert(sum > 0.0); - sum = (1.0 - eps) / sum; + float * w = (float *) ((char *) src0->data + offset); // weight + const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad + float * m = (float *) ((char *) src0_grad_m->data + offset); + float * v = (float *) ((char *) src0_grad_v->data + offset); - // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr - lm_ggml_vec_scale_f32(nc, ds0, sum); - lm_ggml_vec_add1_f32(nc, ds0, ds0, eps); - lm_ggml_vec_sub_f32(nc, ds0, ds0, s1); - lm_ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); + for (int i00 = 0; i00 < ne00; ++i00) { + m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1); + v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2); -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(ds0[i])); - assert(!isinf(ds0[i])); + const float mh = m[i00]*beta1h; + const float vh = sqrtf(v[i00]*beta2h) + eps; + + // The weight decay is applied independently of the Adam momenta m and v. + // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss. + // See: https://arxiv.org/pdf/1711.05101v3.pdf + w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh; } -#endif } + + lm_ggml_barrier(params->threadpool); + if (ith != 0) { + return; + } + + iter++; + memcpy(&dst->op_params[0], &iter, sizeof(int64_t)); } -static void lm_ggml_compute_forward_cross_entropy_loss_back( +static void lm_ggml_compute_forward_opt_step_adamw( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { @@ -16632,7 +17144,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back( switch (src0->type) { case LM_GGML_TYPE_F32: { - lm_ggml_compute_forward_cross_entropy_loss_back_f32(params, dst); + lm_ggml_compute_forward_opt_step_adamw_f32(params, dst); } break; default: { @@ -16640,7 +17152,6 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back( } } } - ///////////////////////////////// static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * tensor) { @@ -16691,6 +17202,14 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru { lm_ggml_compute_forward_log(params, tensor); } break; + case LM_GGML_OP_SIN: + { + lm_ggml_compute_forward_sin(params, tensor); + } break; + case LM_GGML_OP_COS: + { + lm_ggml_compute_forward_cos(params, tensor); + } break; case LM_GGML_OP_SUM: { lm_ggml_compute_forward_sum(params, tensor); @@ -16707,6 +17226,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru { lm_ggml_compute_forward_argmax(params, tensor); } break; + case LM_GGML_OP_COUNT_EQUAL: + { + lm_ggml_compute_forward_count_equal(params, tensor); + } break; case LM_GGML_OP_REPEAT: { lm_ggml_compute_forward_repeat(params, tensor); @@ -16831,6 +17354,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru { lm_ggml_compute_forward_im2col(params, tensor); } break; + case LM_GGML_OP_IM2COL_BACK: + { + lm_ggml_compute_forward_im2col_back_f32(params, tensor); + } break; case LM_GGML_OP_CONV_TRANSPOSE_2D: { lm_ggml_compute_forward_conv_transpose_2d(params, tensor); @@ -16843,6 +17370,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru { lm_ggml_compute_forward_pool_2d(params, tensor); } break; + case LM_GGML_OP_POOL_2D_BACK: + { + lm_ggml_compute_forward_pool_2d_back(params, tensor); + } break; case LM_GGML_OP_UPSCALE: { lm_ggml_compute_forward_upscale(params, tensor); @@ -16906,6 +17437,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru { lm_ggml_compute_forward_add_rel_pos(params, tensor); } break; + case LM_GGML_OP_RWKV_WKV: + { + lm_ggml_compute_forward_rwkv_wkv(params, tensor); + } break; case LM_GGML_OP_MAP_UNARY: { lm_ggml_unary_op_f32_t fun; @@ -16966,6 +17501,11 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru lm_ggml_compute_forward_cross_entropy_loss_back(params, tensor); } break; + case LM_GGML_OP_OPT_STEP_ADAMW: + { + lm_ggml_compute_forward_opt_step_adamw(params, tensor); + } + break; case LM_GGML_OP_NONE: { // nop @@ -17120,7 +17660,7 @@ void lm_ggml_build_backward_gradient_checkpointing( struct lm_ggml_tensor * * checkpoints, int n_checkpoints) { lm_ggml_graph_cpy(gf, gb_tmp); - lm_ggml_build_backward_expand(ctx, gf, gb_tmp, true); + lm_ggml_build_backward_expand(ctx, gf, gb_tmp, false); if (n_checkpoints <= 0) { lm_ggml_graph_cpy(gb_tmp, gb); @@ -17158,42 +17698,93 @@ void lm_ggml_build_backward_gradient_checkpointing( lm_ggml_hash_map_free(replacements); } -// functions to change gradients considering the case that input a might be initial gradient with zero value - -static struct lm_ggml_tensor * lm_ggml_add_or_set(struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, struct lm_ggml_hash_set * zero_table) { +// utility functions to change gradients +// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator +// else if a is in zero_table, replace a +// else, just add/subtract/etc. the gradients + +static struct lm_ggml_tensor * lm_ggml_add_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_hash_set * zero_table, + struct lm_ggml_hash_set * acc_table) { + if (lm_ggml_hash_contains(acc_table, a)) { + struct lm_ggml_tensor * ret = lm_ggml_add_impl(ctx, a, b, true); + const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + return ret; + } if (lm_ggml_hash_contains(zero_table, a)) { return b; - } else { - return lm_ggml_add_impl(ctx, a, b, false); } + return lm_ggml_add_impl(ctx, a, b, false); } -static struct lm_ggml_tensor * lm_ggml_acc_or_set(struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct lm_ggml_hash_set * zero_table) { +static struct lm_ggml_tensor * lm_ggml_acc_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const size_t nb1, + const size_t nb2, + const size_t nb3, + const size_t offset, + struct lm_ggml_hash_set * zero_table, + struct lm_ggml_hash_set * acc_table) { + if (lm_ggml_hash_contains(acc_table, a)) { + struct lm_ggml_tensor * ret = lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); + const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + return ret; + } if (lm_ggml_hash_contains(zero_table, a)) { - struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, a, 0.0f); + struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN return lm_ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); - } else { - return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); } + return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); } -static struct lm_ggml_tensor * lm_ggml_add1_or_set(struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, struct lm_ggml_hash_set * zero_table) { +static struct lm_ggml_tensor * lm_ggml_add1_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_hash_set * zero_table, + struct lm_ggml_hash_set * acc_table) { + if (lm_ggml_hash_contains(acc_table, a)) { + struct lm_ggml_tensor * ret = lm_ggml_add1_impl(ctx, a, b, true); + const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + return ret; + } if (lm_ggml_hash_contains(zero_table, a)) { return lm_ggml_repeat(ctx, b, a); - } else { - return lm_ggml_add1_impl(ctx, a, b, false); } + return lm_ggml_add1_impl(ctx, a, b, false); } -static struct lm_ggml_tensor * lm_ggml_sub_or_set(struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, struct lm_ggml_hash_set * zero_table) { +static struct lm_ggml_tensor * lm_ggml_sub_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_hash_set * zero_table, + struct lm_ggml_hash_set * acc_table) { + if (lm_ggml_hash_contains(acc_table, a)) { + struct lm_ggml_tensor * ret = lm_ggml_sub_impl(ctx, a, b, true); + const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + return ret; + } if (lm_ggml_hash_contains(zero_table, a)) { return lm_ggml_neg(ctx, b); - } else { - return lm_ggml_sub_impl(ctx, a, b, false); } + return lm_ggml_sub_impl(ctx, a, b, false); } -static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor, struct lm_ggml_hash_set * zero_table) { +static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor, struct lm_ggml_hash_set * zero_table, struct lm_ggml_hash_set * acc_table) { struct lm_ggml_tensor * src0 = tensor->src[0]; struct lm_ggml_tensor * src1 = tensor->src[1]; struct lm_ggml_tensor * src2 = tensor->src[2]; @@ -17202,34 +17793,38 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm case LM_GGML_OP_DUP: { if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } } break; case LM_GGML_OP_ADD: { if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } if (src1->grad) { - src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); + if (lm_ggml_are_same_shape(src0, src1)) { + src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); + } else { + src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table); + } } } break; case LM_GGML_OP_ADD1: { if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } if (src1->grad) { src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_ACC: { if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } if (src1->grad) { const size_t nb1 = ((int32_t *) tensor->op_params)[0]; @@ -17251,16 +17846,16 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_reshape(ctx, lm_ggml_cont(ctx, tensor_grad_view), src1->grad), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_SUB: { if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } if (src1->grad) { - src1->grad = lm_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table); + src1->grad = lm_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); } } break; case LM_GGML_OP_MUL: @@ -17270,14 +17865,14 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_add_or_set(ctx, src0->grad, lm_ggml_mul(ctx, src1, tensor->grad), - zero_table); + zero_table, acc_table); } if (src1->grad) { src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_mul(ctx, src0, tensor->grad), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_DIV: @@ -17287,7 +17882,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_add_or_set(ctx, src0->grad, lm_ggml_div(ctx, tensor->grad, src1), - zero_table); + zero_table, acc_table); } if (src1->grad) { src1->grad = @@ -17296,7 +17891,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_mul(ctx, tensor->grad, lm_ggml_div(ctx, tensor, src1)), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_SQR: @@ -17308,7 +17903,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_scale(ctx, lm_ggml_mul(ctx, src0, tensor->grad), 2.0f), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_SQRT: @@ -17322,7 +17917,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm tensor->grad, tensor), 0.5f), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_LOG: @@ -17334,7 +17929,31 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_div(ctx, tensor->grad, src0), - zero_table); + zero_table, acc_table); + } + } break; + case LM_GGML_OP_SIN: + { + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_mul(ctx, + tensor->grad, + lm_ggml_cos(ctx, src0)), + zero_table, acc_table); + } + } break; + case LM_GGML_OP_COS: + { + if (src0->grad) { + src0->grad = + lm_ggml_sub_or_set(ctx, + src0->grad, + lm_ggml_mul(ctx, + tensor->grad, + lm_ggml_sin(ctx, src0)), + zero_table, acc_table); } } break; case LM_GGML_OP_SUM: @@ -17344,7 +17963,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_add1_or_set(ctx, src0->grad, tensor->grad, - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_SUM_ROWS: @@ -17356,11 +17975,12 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_repeat(ctx, tensor->grad, src0->grad), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_MEAN: case LM_GGML_OP_ARGMAX: + case LM_GGML_OP_COUNT_EQUAL: { LM_GGML_ABORT("fatal error"); // TODO: implement } @@ -17371,7 +17991,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src0->grad = lm_ggml_add_or_set(ctx, src0->grad, lm_ggml_repeat_back(ctx, tensor->grad, src0->grad), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_REPEAT_BACK: @@ -17381,7 +18001,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src0->grad = lm_ggml_add_or_set(ctx, src0->grad, lm_ggml_repeat(ctx, tensor->grad, src0->grad), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_CONCAT: @@ -17406,7 +18026,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src0->grad = lm_ggml_add_or_set(ctx, src0->grad, lm_ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_RMS_NORM_BACK: @@ -17454,7 +18074,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_add_or_set(ctx, src0->grad, // [n,m,q1,r1] s1_tg, // [n,m,q1,r1] - zero_table); + zero_table, acc_table); } if (src1->grad) { src1->grad = @@ -17472,7 +18092,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src0, // [n,m,q1,r1] lm_ggml_transpose(ctx, // [p,m,qq,rr] tensor->grad)), // [m,p,qq,rr] - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_MUL_MAT_ID: @@ -17494,7 +18114,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_add_or_set(ctx, src0->grad, lm_ggml_scale_impl(ctx, tensor->grad, s, false), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_SET: @@ -17509,14 +18129,10 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm if (src0->grad || src1->grad) { LM_GGML_ASSERT(src0->type == tensor->type); LM_GGML_ASSERT(tensor->grad->type == tensor->type); - LM_GGML_ASSERT(tensor->grad->type == src1->grad->type); + LM_GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type); tensor_grad_view = lm_ggml_view_4d(ctx, - tensor->grad, - src1->grad->ne[0], - src1->grad->ne[1], - src1->grad->ne[2], - src1->grad->ne[3], + tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], nb1, nb2, nb3, offset); } @@ -17527,7 +18143,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm tensor->grad, lm_ggml_neg(ctx, tensor_grad_view), nb1, nb2, nb3, offset, false), - zero_table); + zero_table, acc_table); } if (src1->grad) { @@ -17537,7 +18153,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_reshape(ctx, lm_ggml_cont(ctx, tensor_grad_view), src1->grad), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_CPY: @@ -17548,7 +18164,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm // tensor = src0 * 1 + src1 * 0 if (src0->grad) { // dsrc0 = dtensor * 1 - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } if (src1->grad) { // dsrc1 = dtensor * 0 -> noop @@ -17560,7 +18176,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm if (src0->grad) { LM_GGML_ASSERT(lm_ggml_is_contiguous(src0->grad)); LM_GGML_ASSERT(lm_ggml_is_contiguous(tensor->grad)); - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } } break; case LM_GGML_OP_RESHAPE: @@ -17574,7 +18190,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm ? tensor->grad : lm_ggml_cont(ctx, tensor->grad), src0->grad), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_VIEW: @@ -17585,9 +18201,9 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm memcpy(&offset, tensor->op_params, sizeof(offset)); - size_t nb1 = tensor->nb[1]; - size_t nb2 = tensor->nb[2]; - size_t nb3 = tensor->nb[3]; + size_t nb1 = tensor->nb[1]; + size_t nb2 = tensor->nb[2]; + size_t nb3 = tensor->nb[3]; if (src0->type != src0->grad->type) { // gradient is typically F32, but src0 could be other type @@ -17603,7 +18219,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm nb3 = (nb3 / n0) * ng; } - src0->grad = lm_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table); + src0->grad = lm_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table); } } break; case LM_GGML_OP_PERMUTE: @@ -17628,7 +18244,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm axes_backward[1], axes_backward[2], axes_backward[3]), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_TRANSPOSE: @@ -17638,7 +18254,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src0->grad = lm_ggml_add_or_set(ctx, src0->grad, lm_ggml_transpose(ctx, tensor->grad), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_GET_ROWS: @@ -17650,7 +18266,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm // last lm_ggml_get_rows_back argument src0->grad is only // necessary to setup correct output shape lm_ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - zero_table); + zero_table, acc_table); } if (src1->grad) { // noop @@ -17674,7 +18290,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm /* lm_ggml_diag_mask_inf_impl() shouldn't be here */ /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_DIAG_MASK_ZERO: @@ -17685,7 +18301,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src0->grad = lm_ggml_add_or_set(ctx, src0->grad, lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_SOFT_MAX: @@ -17695,9 +18311,9 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src0->grad = lm_ggml_add_or_set(ctx, src0->grad, lm_ggml_soft_max_back(ctx, tensor->grad, tensor), - zero_table); + zero_table, acc_table); } - + LM_GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented"); } break; case LM_GGML_OP_SOFT_MAX_BACK: { @@ -17736,8 +18352,9 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm attn_factor, beta_fast, beta_slow), - zero_table); + zero_table, acc_table); } + LM_GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented"); } break; case LM_GGML_OP_ROPE_BACK: { @@ -17772,7 +18389,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm beta_fast, beta_slow, false), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_CLAMP: @@ -17784,6 +18401,23 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm LM_GGML_ABORT("fatal error"); // TODO: not implemented } case LM_GGML_OP_IM2COL: + { + if (src1->grad) { + const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 0); + const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 1); + const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 2); + const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 3); + const int32_t d0 = lm_ggml_get_op_params_i32(tensor, 4); + const int32_t d1 = lm_ggml_get_op_params_i32(tensor, 5); + const bool is_2D = lm_ggml_get_op_params_i32(tensor, 6) == 1; + + src1->grad = lm_ggml_add_or_set(ctx, + src1->grad, + lm_ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D), + zero_table, acc_table); + } + } break; + case LM_GGML_OP_IM2COL_BACK: { LM_GGML_ABORT("fatal error"); // TODO: not implemented } @@ -17796,6 +18430,23 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm LM_GGML_ABORT("fatal error"); // TODO: not implemented } case LM_GGML_OP_POOL_2D: + { + if (src0->grad) { + const enum lm_ggml_op_pool op = lm_ggml_get_op_params_i32(tensor, 0); + const int32_t k0 = lm_ggml_get_op_params_i32(tensor, 1); + const int32_t k1 = lm_ggml_get_op_params_i32(tensor, 2); + const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 3); + const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 4); + const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 5); + const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 6); + + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1), + zero_table, acc_table); + } + } break; + case LM_GGML_OP_POOL_2D_BACK: { LM_GGML_ABORT("fatal error"); // TODO: not implemented } @@ -17825,6 +18476,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm } case LM_GGML_OP_FLASH_ATTN_EXT: { + LM_GGML_ABORT("FA backward pass not adapted after rework"); struct lm_ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { int32_t t = lm_ggml_get_op_params_i32(tensor, 0); @@ -17857,7 +18509,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src0->grad = lm_ggml_add_or_set(ctx, src0->grad, grad_q, - zero_table); + zero_table, acc_table); } if (src1->grad) { struct lm_ggml_tensor * view_k = lm_ggml_view_1d(ctx, flash_grad, elem_k, offs_k); @@ -17865,7 +18517,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src1->grad = lm_ggml_add_or_set(ctx, src1->grad, grad_k, - zero_table); + zero_table, acc_table); } if (src2->grad) { struct lm_ggml_tensor * view_v = lm_ggml_view_1d(ctx, flash_grad, elem_v, offs_v); @@ -17873,7 +18525,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src2->grad = lm_ggml_add_or_set(ctx, src2->grad, grad_v, - zero_table); + zero_table, acc_table); } } break; case LM_GGML_OP_FLASH_ATTN_BACK: @@ -17899,7 +18551,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_mul(ctx, lm_ggml_sgn(ctx, src0), tensor->grad), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_UNARY_OP_SGN: @@ -17911,7 +18563,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm case LM_GGML_UNARY_OP_NEG: { if (src0->grad) { - src0->grad = lm_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = lm_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } } break; case LM_GGML_UNARY_OP_STEP: @@ -17936,7 +18588,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm lm_ggml_mul(ctx, lm_ggml_step(ctx, src0), tensor->grad), - zero_table); + zero_table, acc_table); } } break; case LM_GGML_UNARY_OP_SIGMOID: @@ -17958,7 +18610,16 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src0->grad = lm_ggml_add_or_set(ctx, src0->grad, lm_ggml_silu_back(ctx, src0, tensor->grad), - zero_table); + zero_table, acc_table); + } + } break; + case LM_GGML_UNARY_OP_EXP: + { + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_mul(ctx, tensor, tensor->grad), + zero_table, acc_table); } } break; default: @@ -17967,6 +18628,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm } break; case LM_GGML_OP_GET_REL_POS: case LM_GGML_OP_ADD_REL_POS: + case LM_GGML_OP_RWKV_WKV: case LM_GGML_OP_MAP_UNARY: case LM_GGML_OP_MAP_BINARY: case LM_GGML_OP_MAP_CUSTOM1_F32: @@ -17987,13 +18649,18 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src0, src1, tensor->grad), - zero_table); + zero_table, acc_table); } + LM_GGML_ASSERT(!src1->grad && "backward pass for labels not implemented"); } break; case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK: { LM_GGML_ABORT("fatal error"); // not supported } + case LM_GGML_OP_OPT_STEP_ADAMW: + { + LM_GGML_ABORT("fatal error"); // not supported + } case LM_GGML_OP_NONE: { // nop @@ -18035,7 +18702,7 @@ static void lm_ggml_visit_parents(struct lm_ggml_cgraph * cgraph, struct lm_ggml } } - if (node->op == LM_GGML_OP_NONE && node->grad == NULL) { + if (node->op == LM_GGML_OP_NONE && !(node->flags & LM_GGML_TENSOR_FLAG_PARAM)) { // reached a leaf node, not part of the gradient graph (e.g. a constant) LM_GGML_ASSERT(cgraph->n_leafs < cgraph->size); @@ -18053,9 +18720,6 @@ static void lm_ggml_visit_parents(struct lm_ggml_cgraph * cgraph, struct lm_ggml } cgraph->nodes[cgraph->n_nodes] = node; - if (cgraph->grads) { - cgraph->grads[cgraph->n_nodes] = node->grad; - } cgraph->n_nodes++; } } @@ -18083,36 +18747,93 @@ void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml lm_ggml_build_forward_impl(cgraph, tensor, true); } -void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool keep) { +void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate) { LM_GGML_ASSERT(gf->n_nodes > 0); + LM_GGML_ASSERT(gf->grads); + + for (int i = 0; i < gf->n_nodes; ++i) { + struct lm_ggml_tensor * node = gf->nodes[i]; + + if (node->type == LM_GGML_TYPE_I32) { + continue; + } - // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph - if (keep) { - for (int i = 0; i < gf->n_nodes; i++) { - struct lm_ggml_tensor * node = gf->nodes[i]; + bool needs_grad = node->flags & LM_GGML_TENSOR_FLAG_PARAM; + bool ignore_src[LM_GGML_MAX_SRC] = {false}; + switch (node->op) { + // gradients in node->src[0] for one reason or another have no effect on output gradients + case LM_GGML_OP_IM2COL: // only used for its shape + case LM_GGML_OP_IM2COL_BACK: // same as IM2COL + ignore_src[0] = true; + break; + case LM_GGML_OP_UNARY: { + const enum lm_ggml_unary_op uop = lm_ggml_get_unary_op(node); + // SGN and STEP unary ops are piecewise constant + if (uop == LM_GGML_UNARY_OP_SGN || uop == LM_GGML_UNARY_OP_STEP) { + ignore_src[0] = true; + } + } break; - if (node->grad) { - node->grad = lm_ggml_dup_tensor(ctx, node); - gf->grads[i] = node->grad; + // gradients in node->src[1] for one reason or another have no effect on output gradients + case LM_GGML_OP_CPY: // gradients in CPY target are irrelevant + case LM_GGML_OP_GET_ROWS: // row indices not differentiable + case LM_GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS + case LM_GGML_OP_ROPE: // positions not differentiable + ignore_src[1] = true; + break; + + default: + break; + } + for (int j = 0; j < LM_GGML_MAX_SRC; ++j) { + if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) { + continue; } + LM_GGML_ASSERT(node->src[j]->type == LM_GGML_TYPE_F32 || node->src[j]->type == LM_GGML_TYPE_F16); + needs_grad = true; + break; + } + if (!needs_grad) { + continue; } + + // inplace operations are currently not supported + LM_GGML_ASSERT(!node->view_src || node->op == LM_GGML_OP_CPY || node->op == LM_GGML_OP_VIEW || + node->op == LM_GGML_OP_RESHAPE || node->op == LM_GGML_OP_PERMUTE || node->op == LM_GGML_OP_TRANSPOSE); + + // create a new tensor with the same type and shape as the node and set it as grad + node->grad = lm_ggml_dup_tensor(ctx, node); } - // remember original gradients which start with zero values + // keep tables of original gradients for replacement/accumulation logic struct lm_ggml_hash_set zero_table = lm_ggml_hash_set_new(gf->size); + struct lm_ggml_hash_set acc_table = lm_ggml_hash_set_new(gf->size); for (int i = 0; i < gf->n_nodes; i++) { - if (gf->grads[i]) { - lm_ggml_hash_insert(&zero_table, gf->grads[i]); + struct lm_ggml_tensor * node = gf->nodes[i]; + + if (node->grad) { + { + const size_t insert_result = lm_ggml_hash_insert(&zero_table, node->grad); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + } + + // only gradients of trainable parameters should be accumulated + if (accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) { + const size_t insert_result = lm_ggml_hash_insert(&acc_table, node->grad); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + } } } for (int i = gf->n_nodes - 1; i >= 0; i--) { struct lm_ggml_tensor * node = gf->nodes[i]; - // inplace operations to add gradients are not created by lm_ggml_compute_backward + // inplace operations to add gradients are not created by lm_ggml_compute_backward except for gradient accumulation // use allocator to automatically make inplace operations if (node->grad) { - lm_ggml_compute_backward(ctx, node, &zero_table); + lm_ggml_compute_backward(ctx, node, &zero_table, &acc_table); } } @@ -18126,8 +18847,30 @@ void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_ } lm_ggml_hash_set_free(&zero_table); + lm_ggml_hash_set_free(&acc_table); +} + +void lm_ggml_build_opt_adamw( + struct lm_ggml_context * ctx, + struct lm_ggml_cgraph * gf, + struct lm_ggml_cgraph * gb, + float alpha, + float beta1, + float beta2, + float eps, + float wd) { + for (int i = 0; i < gf->n_nodes; i++) { + struct lm_ggml_tensor * node = gf->nodes[i]; + + if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { + LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); + struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd); + lm_ggml_build_forward_expand(gb, opt_step); + } + } } + static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { void * ptr = *p; ptr = (void *) LM_GGML_PAD((uintptr_t) ptr, align); @@ -18238,7 +18981,8 @@ void lm_ggml_graph_cpy(struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst) } for (size_t i = 0; i < src->visited_hash_set.size; ++i) { - if (src->visited_hash_set.keys[i]) { + // copy all hashset keys (tensors) that are in use + if (lm_ggml_bitset_get(src->visited_hash_set.used, i)) { lm_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); } } @@ -18254,10 +18998,28 @@ void lm_ggml_graph_reset(struct lm_ggml_cgraph * cgraph) { LM_GGML_ASSERT(cgraph->grads != NULL); for (int i = 0; i < cgraph->n_nodes; i++) { - struct lm_ggml_tensor * grad = cgraph->grads[i]; + struct lm_ggml_tensor * node = cgraph->nodes[i]; + + // initial gradients of loss should be 1, 0 otherwise + if (node->grad) { + if (node->flags & LM_GGML_TENSOR_FLAG_LOSS) { + LM_GGML_ASSERT(node->grad->buffer); + LM_GGML_ASSERT(node->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(lm_ggml_is_scalar(node)); + + const float onef = 1.0f; + lm_ggml_backend_tensor_set(node->grad, &onef, 0, lm_ggml_nbytes(node->grad)); + } else { + lm_ggml_set_zero(node->grad); + } + } - if (grad) { - lm_ggml_set_zero(grad); + LM_GGML_ASSERT(node); + if (node->op == LM_GGML_OP_OPT_STEP_ADAMW) { + // set iteration to 1 and clear momenta + lm_ggml_set_op_params_i32(node, 0, 1); + lm_ggml_set_zero(node->src[2]); + lm_ggml_set_zero(node->src[3]); } } } @@ -18268,64 +19030,33 @@ void lm_ggml_graph_clear(struct lm_ggml_cgraph * cgraph) { lm_ggml_hash_set_reset(&cgraph->visited_hash_set); } -// -// thread data -// -// synchronization is done via busy loops -// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops -// - -#ifdef __APPLE__ - -//#include -// -//typedef os_unfair_lock lm_ggml_lock_t; -// -//#define lm_ggml_lock_init(x) UNUSED(x) -//#define lm_ggml_lock_destroy(x) UNUSED(x) -//#define lm_ggml_lock_lock os_unfair_lock_lock -//#define lm_ggml_lock_unlock os_unfair_lock_unlock -// -//#define LM_GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT - -typedef int lm_ggml_lock_t; - -#define lm_ggml_lock_init(x) UNUSED(x) -#define lm_ggml_lock_destroy(x) UNUSED(x) -#define lm_ggml_lock_lock(x) UNUSED(x) -#define lm_ggml_lock_unlock(x) UNUSED(x) - -#define LM_GGML_LOCK_INITIALIZER 0 - -#define lm_ggml_thread_create pthread_create -#define lm_ggml_thread_join pthread_join - -#else - -//typedef pthread_spinlock_t lm_ggml_lock_t; - -//#define lm_ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) -//#define lm_ggml_lock_destroy pthread_spin_destroy -//#define lm_ggml_lock_lock pthread_spin_lock -//#define lm_ggml_lock_unlock pthread_spin_unlock +int lm_ggml_graph_size(struct lm_ggml_cgraph * cgraph) { + return cgraph->size; +} -typedef int lm_ggml_lock_t; +struct lm_ggml_tensor * lm_ggml_graph_node(struct lm_ggml_cgraph * cgraph, int i) { + if (i < 0) { + LM_GGML_ASSERT(cgraph->n_nodes + i >= 0); + return cgraph->nodes[cgraph->n_nodes + i]; + } -#define lm_ggml_lock_init(x) UNUSED(x) -#define lm_ggml_lock_destroy(x) UNUSED(x) -#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) -#define lm_ggml_lock_lock(x) _mm_pause() -#else -#define lm_ggml_lock_lock(x) UNUSED(x) -#endif -#define lm_ggml_lock_unlock(x) UNUSED(x) + LM_GGML_ASSERT(i < cgraph->n_nodes); + return cgraph->nodes[i]; +} -#define LM_GGML_LOCK_INITIALIZER 0 +struct lm_ggml_tensor ** lm_ggml_graph_nodes(struct lm_ggml_cgraph * cgraph) { + return cgraph->nodes; +} -#define lm_ggml_thread_create pthread_create -#define lm_ggml_thread_join pthread_join +int lm_ggml_graph_n_nodes(struct lm_ggml_cgraph * cgraph) { + return cgraph->n_nodes; +} -#endif +void lm_ggml_graph_add_node(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor) { + LM_GGML_ASSERT(cgraph->size > cgraph->n_nodes); + cgraph->nodes[cgraph->n_nodes] = tensor; + cgraph->n_nodes++; +} // Android's libc implementation "bionic" does not support setting affinity #if defined(__gnu_linux__) @@ -18424,10 +19155,19 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { case LM_GGML_OP_SQR: case LM_GGML_OP_SQRT: case LM_GGML_OP_LOG: + case LM_GGML_OP_SIN: + case LM_GGML_OP_COS: case LM_GGML_OP_SUM: case LM_GGML_OP_SUM_ROWS: case LM_GGML_OP_MEAN: case LM_GGML_OP_ARGMAX: + { + n_tasks = 1; + } break; + case LM_GGML_OP_COUNT_EQUAL: + { + n_tasks = n_threads; + } break; case LM_GGML_OP_REPEAT: case LM_GGML_OP_REPEAT_BACK: case LM_GGML_OP_LEAKY_RELU: @@ -18446,6 +19186,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { case LM_GGML_UNARY_OP_SIGMOID: case LM_GGML_UNARY_OP_HARDSWISH: case LM_GGML_UNARY_OP_HARDSIGMOID: + case LM_GGML_UNARY_OP_EXP: { n_tasks = 1; } break; @@ -18510,6 +19251,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { n_tasks = MIN(n_threads, lm_ggml_nrows(node->src[0])); } break; case LM_GGML_OP_IM2COL: + case LM_GGML_OP_IM2COL_BACK: case LM_GGML_OP_CONV_TRANSPOSE_1D: case LM_GGML_OP_CONV_TRANSPOSE_2D: { @@ -18517,6 +19259,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { } break; case LM_GGML_OP_POOL_1D: case LM_GGML_OP_POOL_2D: + case LM_GGML_OP_POOL_2D_BACK: { n_tasks = 1; } break; @@ -18535,6 +19278,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { case LM_GGML_OP_WIN_PART: case LM_GGML_OP_WIN_UNPART: case LM_GGML_OP_GET_REL_POS: + case LM_GGML_OP_RWKV_WKV: case LM_GGML_OP_MAP_UNARY: case LM_GGML_OP_MAP_BINARY: case LM_GGML_OP_MAP_CUSTOM1_F32: @@ -18575,6 +19319,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { } break; case LM_GGML_OP_CROSS_ENTROPY_LOSS: case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case LM_GGML_OP_OPT_STEP_ADAMW: { n_tasks = n_threads; } break; @@ -18598,14 +19343,288 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { } } - assert(n_tasks > 0); + assert(n_tasks > 0); + + return n_tasks; +} + +static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data); + +#if defined(_WIN32) +#include "windows.h" + +// TODO: support > 64 CPUs +bool lm_ggml_thread_apply_affinity(bool * mask) { + HANDLE h = GetCurrentThread(); + uint64_t bitmask = 0ULL; + + assert(LM_GGML_MAX_N_THREADS >= 64); + + for (int32_t i = 0; i < 8; i++) { + int32_t idx = i * 8; + uint8_t val = 0; + val |= mask[idx + 0] << 0; + val |= mask[idx + 1] << 1; + val |= mask[idx + 2] << 2; + val |= mask[idx + 3] << 3; + val |= mask[idx + 4] << 4; + val |= mask[idx + 5] << 5; + val |= mask[idx + 6] << 6; + val |= mask[idx + 7] << 7; + bitmask |= (uint64_t)val << idx; + } + + for (int32_t i = 64; i < LM_GGML_MAX_N_THREADS; i++) { + if (mask[i]) { + fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n"); + break; + } + } + + DWORD_PTR m = (DWORD_PTR)bitmask; + + m = SetThreadAffinityMask(h, m); + + return m != 0; +} + +static bool lm_ggml_thread_apply_priority(int32_t prio) { + // Note that on Windows the Process Priority Class must be updated in order to set Thread priority. + // This is up to the applications. + DWORD p = THREAD_PRIORITY_NORMAL; + switch (prio) { + case LM_GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break; + case LM_GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break; + case LM_GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break; + case LM_GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break; + } + + if (prio == LM_GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + if (!SetThreadPriority(GetCurrentThread(), p)) { + fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError()); + return false; + } + + return true; +} + +#elif defined(__APPLE__) +#include +#include + +static bool lm_ggml_thread_apply_affinity(const bool * mask) { + // Not supported on Apple platforms + UNUSED(mask); + return true; +} + +static bool lm_ggml_thread_apply_priority(int32_t prio) { + struct sched_param p; + int32_t policy = SCHED_OTHER; + switch (prio) { + case LM_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; + case LM_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; + case LM_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; + case LM_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; + } + + if (prio == LM_GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + int32_t err = pthread_setschedparam(pthread_self(), policy, &p); + if (err != 0) { + fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); + return false; + } + + return true; +} + +#elif defined(__gnu_linux__) +// TODO: this may not work on BSD, to be verified + +static bool lm_ggml_thread_apply_affinity(const bool * mask) { + cpu_set_t cpuset; + int err; + + CPU_ZERO(&cpuset); + + for (uint32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) { + if (mask[i]) { + LM_GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i); + CPU_SET(i, &cpuset); + } + } + +#ifdef __ANDROID__ + err = sched_setaffinity(0, sizeof(cpuset), &cpuset); + if (err < 0) { + err = errno; + } +#else + err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset); +#endif + if (err != 0) { + fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err); + return false; + } + + return true; +} + +static bool lm_ggml_thread_apply_priority(int32_t prio) { + struct sched_param p; + int32_t policy = SCHED_OTHER; + switch (prio) { + case LM_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; + case LM_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; + case LM_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; + case LM_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; + } + + if (prio == LM_GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + int32_t err = pthread_setschedparam(pthread_self(), policy, &p); + if (err != 0) { + fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); + return false; + } + + return true; +} + +#else // unsupported platforms + +static bool lm_ggml_thread_apply_affinity(const bool * mask) { + UNUSED(mask); + return true; +} + +static bool lm_ggml_thread_apply_priority(int32_t prio) { + UNUSED(prio); + return true; +} + +#endif + +static bool lm_ggml_thread_cpumask_is_valid(const bool * mask) { + for (int i = 0; i < LM_GGML_MAX_N_THREADS; i++) { + if (mask[i]) { return true; } + } + return false; +} + +static void lm_ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) { + if (!strict) { + memcpy(local_mask, global_mask, LM_GGML_MAX_N_THREADS); + return; + } else { + memset(local_mask, 0, LM_GGML_MAX_N_THREADS); + int32_t base_idx = *iter; + for (int32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) { + int32_t idx = base_idx + i; + if (idx >= LM_GGML_MAX_N_THREADS) { + // Just a cheaper modulo + idx -= LM_GGML_MAX_N_THREADS; + } + if (global_mask[idx]) { + local_mask[idx] = 1; + *iter = idx + 1; + return; + } + } + } +} + +void lm_ggml_threadpool_free(struct lm_ggml_threadpool* threadpool) { + if (!threadpool) return; - return n_tasks; + const int n_threads = threadpool->n_threads_max; + +#ifndef LM_GGML_USE_OPENMP + struct lm_ggml_compute_state* workers = threadpool->workers; + + lm_ggml_mutex_lock(&threadpool->mutex); + + threadpool->stop = true; + threadpool->pause = false; + + lm_ggml_cond_broadcast(&threadpool->cond); + lm_ggml_mutex_unlock(&threadpool->mutex); + + for (int j = 1; j < n_threads; j++) { + int32_t rc = lm_ggml_thread_join(workers[j].thrd, NULL); + LM_GGML_ASSERT(rc == LM_GGML_EXIT_SUCCESS || rc == LM_GGML_EXIT_ABORTED); + UNUSED(rc); + } + + lm_ggml_mutex_destroy(&threadpool->mutex); + lm_ggml_cond_destroy(&threadpool->cond); +#endif // LM_GGML_USE_OPENMP + + const size_t workers_size = sizeof(struct lm_ggml_compute_state) * n_threads; + lm_ggml_aligned_free(threadpool->workers, workers_size); + lm_ggml_aligned_free(threadpool, sizeof(struct lm_ggml_threadpool)); +} + +#ifndef LM_GGML_USE_OPENMP +// pause/resume must be called under mutex +static void lm_ggml_threadpool_pause_locked(struct lm_ggml_threadpool * threadpool) { + LM_GGML_PRINT_DEBUG("Pausing threadpool\n"); + threadpool->pause = true; + lm_ggml_cond_broadcast(&threadpool->cond); +} + +static void lm_ggml_threadpool_resume_locked(struct lm_ggml_threadpool * threadpool) { + LM_GGML_PRINT_DEBUG("Resuming threadpool\n"); + threadpool->pause = false; + lm_ggml_cond_broadcast(&threadpool->cond); } +#endif + +void lm_ggml_threadpool_pause(struct lm_ggml_threadpool * threadpool) { +#ifndef LM_GGML_USE_OPENMP + lm_ggml_mutex_lock(&threadpool->mutex); + if (!threadpool->pause) { + lm_ggml_threadpool_pause_locked(threadpool); + } + lm_ggml_mutex_unlock(&threadpool->mutex); +#else + UNUSED(threadpool); +#endif +} + +void lm_ggml_threadpool_resume(struct lm_ggml_threadpool * threadpool) { +#ifndef LM_GGML_USE_OPENMP + lm_ggml_mutex_lock(&threadpool->mutex); + if (threadpool->pause) { + lm_ggml_threadpool_resume_locked(threadpool); + } + lm_ggml_mutex_unlock(&threadpool->mutex); +#else + UNUSED(threadpool); +#endif +} + +struct lm_ggml_cplan lm_ggml_graph_plan( + const struct lm_ggml_cgraph * cgraph, + int n_threads, + struct lm_ggml_threadpool * threadpool) { -struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, int n_threads) { + if (threadpool == NULL) { + LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); + } if (n_threads <= 0) { - n_threads = LM_GGML_DEFAULT_N_THREADS; + n_threads = threadpool ? threadpool->n_threads_max : LM_GGML_DEFAULT_N_THREADS; } size_t work_size = 0; @@ -18649,6 +19668,10 @@ struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, in cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; } } break; + case LM_GGML_OP_COUNT_EQUAL: + { + cur = lm_ggml_type_size(node->type)*n_tasks; + } break; case LM_GGML_OP_MUL_MAT: { const enum lm_ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; @@ -18761,49 +19784,290 @@ struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, in } if (work_size > 0) { - work_size += CACHE_LINE_SIZE*(n_threads - 1); + work_size += CACHE_LINE_SIZE*(n_threads); } - cplan.n_threads = MIN(max_tasks, n_threads); - cplan.work_size = work_size; - cplan.work_data = NULL; + cplan.threadpool = threadpool; + cplan.n_threads = MIN(max_tasks, n_threads); + cplan.work_size = work_size; + cplan.work_data = NULL; return cplan; } static thread_ret_t lm_ggml_graph_compute_thread(void * data) { struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data; + struct lm_ggml_threadpool * tp = state->threadpool; - const struct lm_ggml_cgraph * cgraph = state->shared->cgraph; - const struct lm_ggml_cplan * cplan = state->shared->cplan; + const struct lm_ggml_cgraph * cgraph = tp->cgraph; + const struct lm_ggml_cplan * cplan = tp->cplan; set_numa_thread_affinity(state->ith); struct lm_ggml_compute_params params = { - /*.ith =*/ state->ith, - /*.nth =*/ state->shared->n_threads, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - /*.shared=*/ state->shared, + /*.ith =*/ state->ith, + /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed), + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + /*.threadpool=*/ tp, }; - for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) { struct lm_ggml_tensor * node = cgraph->nodes[node_n]; lm_ggml_compute_forward(¶ms, node); - if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - state->shared->ec = LM_GGML_STATUS_ABORTED; + if (state->ith == 0 && cplan->abort_callback && + cplan->abort_callback(cplan->abort_callback_data)) { + tp->abort = true; + tp->ec = LM_GGML_STATUS_ABORTED; + } + + lm_ggml_barrier(state->threadpool); + } + + return 0; +} + +#ifndef LM_GGML_USE_OPENMP + +// check if thread is active +static inline bool lm_ggml_graph_compute_thread_active(struct lm_ggml_compute_state * state) { + struct lm_ggml_threadpool * threadpool = state->threadpool; + int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed); + return (state->ith < n_threads); +} + +// check if thread is ready to proceed (exit from polling or sleeping) +static inline bool lm_ggml_graph_compute_thread_ready(struct lm_ggml_compute_state * state) { + struct lm_ggml_threadpool * threadpool = state->threadpool; + + if (state->pending || threadpool->stop || threadpool->pause) { return true; } + + // check for new graph/work + int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed); + if (new_graph != state->last_graph) { + state->pending = lm_ggml_graph_compute_thread_active(state); + state->last_graph = new_graph; + } + + return state->pending; +} + +// sync thread state after polling +static inline void lm_ggml_graph_compute_thread_sync(struct lm_ggml_compute_state * state) { + // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead + #ifdef LM_GGML_TSAN_ENABLED + atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst); + #else + atomic_thread_fence(memory_order_seq_cst); + #endif + UNUSED(state); +} + +static inline bool lm_ggml_graph_compute_poll_for_work(struct lm_ggml_compute_state * state) { + struct lm_ggml_threadpool * threadpool = state->threadpool; + + // Skip polling for unused threads + if (!lm_ggml_graph_compute_thread_active(state)) { + return state->pending; + } + + // This seems to make 0 ... 100 a decent range for polling level across modern processors. + // Perhaps, we can adjust it dynamically based on load and things. + const uint64_t n_rounds = 1024UL * 128 * threadpool->poll; + + for (uint64_t i=0; !lm_ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) { + // No new work. Keep polling. + lm_ggml_thread_cpu_relax(); + } + + return state->pending; +} + +static inline bool lm_ggml_graph_compute_check_for_work(struct lm_ggml_compute_state * state) { + struct lm_ggml_threadpool * threadpool = state->threadpool; + + if (lm_ggml_graph_compute_poll_for_work(state)) { + lm_ggml_graph_compute_thread_sync(state); + return state->pending; + } + + lm_ggml_mutex_lock_shared(&threadpool->mutex); + while (!lm_ggml_graph_compute_thread_ready(state)) { + // No new work. Wait for the signal. + LM_GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith); + lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex); + } + lm_ggml_mutex_unlock_shared(&threadpool->mutex); + + return state->pending; +} + +static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data) { + struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data; + struct lm_ggml_threadpool * threadpool = state->threadpool; + + lm_ggml_thread_apply_priority(threadpool->prio); + if (lm_ggml_thread_cpumask_is_valid(state->cpumask)) { + lm_ggml_thread_apply_affinity(state->cpumask); + } + + while (true) { + // Check if we need to sleep + while (threadpool->pause) { + LM_GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith); + lm_ggml_mutex_lock_shared(&threadpool->mutex); + if (threadpool->pause) { + lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex); + } + LM_GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith); + lm_ggml_mutex_unlock_shared(&threadpool->mutex); } - lm_ggml_barrier(state->shared); + // This needs to be checked for after the cond_wait + if (threadpool->stop) break; - if (state->shared->ec != LM_GGML_STATUS_SUCCESS) { - break; + // Check if there is new work + // The main thread is the only one that can dispatch new work + + lm_ggml_graph_compute_check_for_work(state); + if (state->pending) { + state->pending = false; + + lm_ggml_graph_compute_thread(state); } } - return 0; + return (thread_ret_t) 0; +} + +// Start processing new graph +static void lm_ggml_graph_compute_kickoff(struct lm_ggml_threadpool * threadpool, int n_threads) +{ + // Always take the mutex here because the worker threads are doing hybrid poll/wait + + lm_ggml_mutex_lock(&threadpool->mutex); + + LM_GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads); + + // Update the number of active threads + atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); + + // Indicate the graph is ready to be processed + // We need the full seq-cst fence here because of the polling threads (used in thread_sync) + atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst); + + if (threadpool->pause) { + // Update main thread prio and affinity to match the threadpool settings + lm_ggml_thread_apply_priority(threadpool->prio); + if (lm_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { + lm_ggml_thread_apply_affinity(threadpool->workers[0].cpumask); + } + + // resume does cond broadcast + lm_ggml_threadpool_resume_locked(threadpool); + } else { + lm_ggml_cond_broadcast(&threadpool->cond); + } + + lm_ggml_mutex_unlock(&threadpool->mutex); +} + +#endif // LM_GGML_USE_OPENMP + +void lm_ggml_threadpool_params_init(struct lm_ggml_threadpool_params * p, int n_threads) { + p->n_threads = n_threads; + p->prio = 0; // default priority (usually means normal or inherited) + p->poll = 50; // hybrid-polling enabled + p->strict_cpu = false; // no strict placement (all threads share same cpumask) + p->paused = false; // threads are ready to go + memset(p->cpumask, 0, LM_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited) +} + +struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads) { + struct lm_ggml_threadpool_params p; + lm_ggml_threadpool_params_init(&p, n_threads); + return p; +} + +bool lm_ggml_threadpool_params_match(const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1) { + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; + return memcmp(p0->cpumask, p1->cpumask, LM_GGML_MAX_N_THREADS) == 0; +} + +static struct lm_ggml_threadpool * lm_ggml_threadpool_new_impl( + struct lm_ggml_threadpool_params * tpp, + struct lm_ggml_cgraph * cgraph, + struct lm_ggml_cplan * cplan) { + + struct lm_ggml_threadpool * threadpool = + lm_ggml_aligned_malloc(sizeof(struct lm_ggml_threadpool)); + { + threadpool->cgraph = cgraph; + threadpool->cplan = cplan; + threadpool->n_graph = 0; + threadpool->n_barrier = 0; + threadpool->n_barrier_passed = 0; + threadpool->current_chunk = 0; + threadpool->stop = false; + threadpool->pause = tpp->paused; + threadpool->abort = false; + threadpool->workers = NULL; + threadpool->n_threads_max = tpp->n_threads; + threadpool->n_threads_cur = tpp->n_threads; + threadpool->poll = tpp->poll; + threadpool->prio = tpp->prio; + threadpool->ec = LM_GGML_STATUS_SUCCESS; + } + + // Allocate and init workers state + const size_t workers_size = sizeof(struct lm_ggml_compute_state) * tpp->n_threads; + struct lm_ggml_compute_state * workers = lm_ggml_aligned_malloc(workers_size); + + memset(workers, 0, workers_size); + for (int j = 0; j < tpp->n_threads; j++) { + workers[j].threadpool = threadpool; + workers[j].ith = j; + } + + threadpool->workers = workers; + +#ifndef LM_GGML_USE_OPENMP + lm_ggml_mutex_init(&threadpool->mutex); + lm_ggml_cond_init(&threadpool->cond); + + // Spin the threads for all workers, and update CPU placements. + // Place the main thread last (towards the higher numbered CPU cores). + + int32_t cpumask_iter = 0; + + for (int j = 1; j < tpp->n_threads; j++) { + lm_ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter); + + int32_t rc = lm_ggml_thread_create(&workers[j].thrd, NULL, lm_ggml_graph_compute_secondary_thread, &workers[j]); + LM_GGML_ASSERT(rc == 0); + } + + lm_ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter); + + if (!threadpool->pause) { + // Update main thread prio and affinity at the start, otherwise we'll do it in resume + lm_ggml_thread_apply_priority(threadpool->prio); + if (lm_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { + lm_ggml_thread_apply_affinity(threadpool->workers[0].cpumask); + } + } +#endif // LM_GGML_USE_OPENMP + + return threadpool; +} + +struct lm_ggml_threadpool * lm_ggml_threadpool_new(struct lm_ggml_threadpool_params * tpp) { + return lm_ggml_threadpool_new_impl(tpp, NULL, NULL); } enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan) { @@ -18811,19 +20075,26 @@ enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct LM_GGML_ASSERT(cplan->n_threads > 0); LM_GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL); - int n_threads = cplan->n_threads; - - struct lm_ggml_compute_state_shared state_shared = { - /*.cgraph =*/ cgraph, - /*.cgraph_plan =*/ cplan, - /*.n_threads =*/ n_threads, - /*.n_barrier =*/ 0, - /*.n_barrier_passed =*/ 0, - /*.abort_callback =*/ NULL, - /*.abort_callback_data =*/ NULL, - /*.current_chunk =*/ 0, - /*.ec =*/ LM_GGML_STATUS_SUCCESS, - }; + int n_threads = cplan->n_threads; + struct lm_ggml_threadpool * threadpool = cplan->threadpool; + + bool disposable_threadpool = false; + + if (threadpool == NULL) { + LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); + disposable_threadpool = true; + + struct lm_ggml_threadpool_params ttp = lm_ggml_threadpool_params_default(n_threads); + threadpool = lm_ggml_threadpool_new_impl(&ttp, cgraph, cplan); + } else { + // Reset some of the parameters that need resetting + // No worker threads should be accessing the parameters below at this stage + threadpool->cgraph = cgraph; + threadpool->cplan = cplan; + threadpool->current_chunk = 0; + threadpool->abort = false; + threadpool->ec = LM_GGML_STATUS_SUCCESS; + } #ifdef LM_GGML_USE_OPENMP if (n_threads > 1) { @@ -18833,63 +20104,42 @@ enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct { // update the number of threads from the actual number of threads that we got from OpenMP n_threads = omp_get_num_threads(); - state_shared.n_threads = n_threads; + atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); } - struct lm_ggml_compute_state worker = { - .thrd = 0, - .ith = omp_get_thread_num(), - .shared = &state_shared, - }; - lm_ggml_graph_compute_thread(&worker); + lm_ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]); } } else { - struct lm_ggml_compute_state worker = { - .thrd = 0, - .ith = 0, - .shared = &state_shared, - }; - lm_ggml_graph_compute_thread(&worker); + atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed); + lm_ggml_graph_compute_thread(&threadpool->workers[0]); } #else - struct lm_ggml_compute_state * workers = alloca(sizeof(struct lm_ggml_compute_state)*n_threads); - - for (int j = 0; j < n_threads; ++j) { - workers[j] = (struct lm_ggml_compute_state) { - .thrd = 0, - .ith = j, - .shared = &state_shared, - }; - } - - // create thread pool - for (int j = 1; j < n_threads; ++j) { - const int rc = lm_ggml_thread_create(&workers[j].thrd, NULL, lm_ggml_graph_compute_thread, &workers[j]); - LM_GGML_ASSERT(rc == 0); - UNUSED(rc); + if (n_threads > threadpool->n_threads_max) { + LM_GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max); + n_threads = threadpool->n_threads_max; } - // this is a work thread too - lm_ggml_graph_compute_thread(&workers[0]); + // Kick all threads to start the new graph + lm_ggml_graph_compute_kickoff(threadpool, n_threads); - // join or kill thread pool - if (n_threads > 1) { - for (int j = 1; j < n_threads; j++) { - const int rc = lm_ggml_thread_join(workers[j].thrd, NULL); - LM_GGML_ASSERT(rc == 0); - UNUSED(rc); - } - } + // This is a work thread too + lm_ggml_graph_compute_thread(&threadpool->workers[0]); #endif // don't leave affinity set on the main thread clear_numa_thread_affinity(); - return state_shared.ec; + enum lm_ggml_status ret = threadpool->ec; + + if (disposable_threadpool) { + lm_ggml_threadpool_free(threadpool); + } + + return ret; } enum lm_ggml_status lm_ggml_graph_compute_with_ctx(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int n_threads) { - struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, n_threads); + struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, n_threads, NULL); struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); @@ -18951,7 +20201,6 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna uint64_t size_eval = 0; // compute size of intermediate results - // TODO: does not take into account scratch buffers !!!! for (int i = 0; i < cgraph->n_nodes; ++i) { size_eval += lm_ggml_nbytes_pad(cgraph->nodes[i]); } @@ -19030,9 +20279,11 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna const uint32_t type = tensor->type; const uint32_t op = tensor->op; + const int32_t flags = tensor->flags; fwrite(&type, sizeof(uint32_t), 1, fout); fwrite(&op, sizeof(uint32_t), 1, fout); + fwrite(&flags, sizeof(int32_t), 1, fout); for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { const uint64_t ne = tensor->ne[j]; @@ -19062,9 +20313,11 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna const uint32_t type = tensor->type; const uint32_t op = tensor->op; + const int32_t flags = tensor->flags; fwrite(&type, sizeof(uint32_t), 1, fout); fwrite(&op, sizeof(uint32_t), 1, fout); + fwrite(&flags, sizeof(int32_t), 1, fout); for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { const uint64_t ne = tensor->ne[j]; @@ -19123,6 +20376,14 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna } } } + + // dump the data + // TODO: pad this to 32 byte boundary + if ((flags & LM_GGML_TENSOR_FLAG_PARAM)) { + const size_t size = lm_ggml_nbytes(tensor); + + fwrite(tensor->data, sizeof(char), size, fout); + } } } @@ -19236,10 +20497,12 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ { uint32_t type; uint32_t op; + int32_t flags; for (uint32_t i = 0; i < n_leafs; ++i) { type = *(const uint32_t *) ptr; ptr += sizeof(type); op = *(const uint32_t *) ptr; ptr += sizeof(op); + flags = *(const int32_t *) ptr; ptr += sizeof(flags); int64_t ne[LM_GGML_MAX_DIMS]; size_t nb[LM_GGML_MAX_DIMS]; @@ -19257,20 +20520,19 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ struct lm_ggml_tensor * tensor = lm_ggml_new_tensor(*ctx_eval, (enum lm_ggml_type) type, LM_GGML_MAX_DIMS, ne); - tensor->op = (enum lm_ggml_op) op; + tensor->op = (enum lm_ggml_op) op; + tensor->flags = flags; memcpy(tensor->name, ptr, LM_GGML_MAX_NAME); ptr += LM_GGML_MAX_NAME; memcpy(tensor->op_params, ptr, LM_GGML_MAX_OP_PARAMS); ptr += LM_GGML_MAX_OP_PARAMS; - tensor->data = (void *) ptr; - for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { tensor->nb[j] = nb[j]; } - result->leafs[i] = tensor; + tensor->data = (void *) ptr; ptr += lm_ggml_nbytes(tensor); - ptr += lm_ggml_nbytes(tensor); + result->leafs[i] = tensor; fprintf(stderr, "%s: loaded leaf %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, lm_ggml_nbytes(tensor)); } @@ -19282,10 +20544,12 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ { uint32_t type; uint32_t op; + int32_t flags; for (uint32_t i = 0; i < n_nodes; ++i) { type = *(const uint32_t *) ptr; ptr += sizeof(type); op = *(const uint32_t *) ptr; ptr += sizeof(op); + flags = *(const int32_t *) ptr; ptr += sizeof(flags); enum lm_ggml_op eop = (enum lm_ggml_op) op; @@ -19375,6 +20639,11 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ result->nodes[i] = tensor; + // TODO tensor data is be duplicated due to lm_ggml_new_tensor call above + if (flags & LM_GGML_TENSOR_FLAG_PARAM) { + tensor->data = (void *) ptr; ptr += lm_ggml_nbytes(tensor); + } + fprintf(stderr, "%s: loaded node %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, lm_ggml_nbytes(tensor)); } } @@ -19384,30 +20653,30 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ } void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) { - LM_GGML_PRINT("=== GRAPH ===\n"); + LM_GGML_LOG_INFO("=== GRAPH ===\n"); - LM_GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); + LM_GGML_LOG_INFO("n_nodes = %d\n", cgraph->n_nodes); for (int i = 0; i < cgraph->n_nodes; i++) { struct lm_ggml_tensor * node = cgraph->nodes[i]; - LM_GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", + LM_GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", i, node->ne[0], node->ne[1], node->ne[2], lm_ggml_op_name(node->op), (node->flags & LM_GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " "); } - LM_GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); + LM_GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs); for (int i = 0; i < cgraph->n_leafs; i++) { struct lm_ggml_tensor * node = cgraph->leafs[i]; - LM_GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", + LM_GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", i, node->ne[0], node->ne[1], lm_ggml_op_name(node->op), lm_ggml_get_name(node)); } - LM_GGML_PRINT("========================================\n"); + LM_GGML_LOG_INFO("========================================\n"); } // check if node is part of the graph @@ -19578,7 +20847,7 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg fclose(fp); - LM_GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); + LM_GGML_LOG_INFO("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); } //////////////////////////////////////////////////////////////////////////////// @@ -19643,6 +20912,7 @@ static enum lm_ggml_opt_result lm_ggml_opt_adam( lm_ggml_opt_callback callback, void * callback_data) { LM_GGML_ASSERT(lm_ggml_is_scalar(f)); + LM_GGML_ASSERT(f->type == LM_GGML_TYPE_F32); // these will store the parameters we want to optimize struct lm_ggml_tensor * ps[LM_GGML_MAX_PARAMS]; @@ -19684,7 +20954,7 @@ static enum lm_ggml_opt_result lm_ggml_opt_adam( float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values - struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads); + struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads, NULL); struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; @@ -20031,7 +21301,7 @@ static enum lm_ggml_opt_result lm_ggml_opt_lbfgs( opt->iter = iter; } - struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads); + struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads, NULL); struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; @@ -20449,7 +21719,7 @@ enum lm_ggml_opt_result lm_ggml_opt_resume( lm_ggml_build_forward_expand(gf, f); struct lm_ggml_cgraph * gb = lm_ggml_graph_dup(ctx, gf); - lm_ggml_build_backward_expand(ctx, gf, gb, true); + lm_ggml_build_backward_expand(ctx, gf, gb, false); return lm_ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); } @@ -20463,6 +21733,8 @@ enum lm_ggml_opt_result lm_ggml_opt_resume_g( lm_ggml_opt_callback callback, void * callback_data) { + LM_GGML_ASSERT(f->grad && "lm_ggml_set_param must be called for at least one ancestor"); + // build forward + backward compute graphs enum lm_ggml_opt_result result = LM_GGML_OPT_RESULT_OK; @@ -20500,6 +21772,17 @@ void lm_ggml_set_output(struct lm_ggml_tensor * tensor) { tensor->flags |= LM_GGML_TENSOR_FLAG_OUTPUT; } +void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor) { + LM_GGML_UNUSED(ctx); // TODO: remove this parameter + tensor->flags |= LM_GGML_TENSOR_FLAG_PARAM; +} + +void lm_ggml_set_loss(struct lm_ggml_tensor * tensor) { + LM_GGML_ASSERT(lm_ggml_is_scalar(tensor)); + LM_GGML_ASSERT(tensor->type == LM_GGML_TYPE_F32); + tensor->flags |= LM_GGML_TENSOR_FLAG_LOSS; +} + //////////////////////////////////////////////////////////////////////////////// void lm_ggml_quantize_init(enum lm_ggml_type type) { @@ -20574,6 +21857,8 @@ size_t lm_ggml_quantize_chunk( case LM_GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case LM_GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case LM_GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -20726,18 +22011,46 @@ static size_t lm_gguf_type_size(enum lm_gguf_type type) { return LM_GGUF_TYPE_SIZE[type]; } -static void lm_gguf_tensor_info_sanitize(struct lm_gguf_tensor_info * info) { - LM_GGML_ASSERT(info->n_dims <= LM_GGML_MAX_DIMS); - LM_GGML_ASSERT(0 <= info->type && info->type < LM_GGML_TYPE_COUNT); +static bool lm_gguf_tensor_info_sanitize(struct lm_gguf_tensor_info * info) { + if (info->n_dims > LM_GGML_MAX_DIMS) { + fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims); + return false; + } + + if (info->type < 0 || info->type >= LM_GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type); + return false; + } + + if (strlen(info->name.data) >= LM_GGML_MAX_NAME) { + fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data); + return false; + } for (uint32_t i = 0; i < info->n_dims; ++i) { - LM_GGML_ASSERT(info->ne[i] > 0); + if (info->ne[i] <= 0) { + fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]); + return false; + } } // prevent overflow for total number of elements - LM_GGML_ASSERT(INT64_MAX/info->ne[1] > info->ne[0]); - LM_GGML_ASSERT(INT64_MAX/info->ne[2] > info->ne[0]*info->ne[1]); - LM_GGML_ASSERT(INT64_MAX/info->ne[3] > info->ne[0]*info->ne[1]*info->ne[2]); + if (INT64_MAX/info->ne[1] <= info->ne[0]) { + fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]); + return false; + } + + if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) { + fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]); + return false; + } + + if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) { + fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]); + return false; + } + + return true; } static bool lm_gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) { @@ -20760,7 +22073,11 @@ static bool lm_gguf_fread_str(FILE * file, struct lm_gguf_str * p, size_t * offs return false; } - p->data = LM_GGML_CALLOC(p->n + 1, 1); + p->data = calloc(p->n + 1, 1); + if (!p->data) { + fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n); + return false; + } ok = ok && lm_gguf_fread_el(file, p->data, p->n, offset); @@ -20794,7 +22111,11 @@ static void lm_gguf_free_kv(struct lm_gguf_kv * kv) { } struct lm_gguf_context * lm_gguf_init_empty(void) { - struct lm_gguf_context * ctx = LM_GGML_CALLOC(1, sizeof(struct lm_gguf_context)); + struct lm_gguf_context * ctx = calloc(1, sizeof(struct lm_gguf_context)); + if (!ctx) { + fprintf(stderr, "%s: failed to allocate memory for context\n", __func__); + return NULL; + } memcpy(ctx->header.magic, LM_GGUF_MAGIC, sizeof(ctx->header.magic)); ctx->header.version = LM_GGUF_VERSION; @@ -20840,7 +22161,12 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg bool ok = true; - struct lm_gguf_context * ctx = LM_GGML_CALLOC(1, sizeof(struct lm_gguf_context)); + struct lm_gguf_context * ctx = calloc(1, sizeof(struct lm_gguf_context)); + if (!ctx) { + fprintf(stderr, "%s: failed to allocate memory for context\n", __func__); + fclose(file); + return NULL; + } // read the header { @@ -20879,9 +22205,13 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg { const uint64_t n_kv = ctx->header.n_kv; - // header.n_kv will hold the actual value of pairs that were successfully read in the loop below - ctx->header.n_kv = 0; - ctx->kv = LM_GGML_CALLOC(n_kv, sizeof(struct lm_gguf_kv)); + ctx->kv = calloc(n_kv, sizeof(struct lm_gguf_kv)); + if (!ctx->kv) { + fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__); + fclose(file); + lm_gguf_free(ctx); + return NULL; + } for (uint64_t i = 0; i < n_kv; ++i) { struct lm_gguf_kv * kv = &ctx->kv[i]; @@ -20932,7 +22262,13 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg return NULL; } - kv->value.arr.data = LM_GGML_CALLOC(kv->value.arr.n, lm_gguf_type_size(kv->value.arr.type)); + kv->value.arr.data = calloc(kv->value.arr.n, lm_gguf_type_size(kv->value.arr.type)); + if (!kv->value.arr.data) { + fprintf(stderr, "%s: failed to allocate memory for array\n", __func__); + fclose(file); + lm_gguf_free(ctx); + return NULL; + } ok = ok && lm_gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * lm_gguf_type_size(kv->value.arr.type), &offset); } break; @@ -20946,24 +22282,36 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg return NULL; } - kv->value.arr.data = LM_GGML_CALLOC(kv->value.arr.n, sizeof(struct lm_gguf_str)); + kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct lm_gguf_str)); + if (!kv->value.arr.data) { + fprintf(stderr, "%s: failed to allocate memory for array\n", __func__); + fclose(file); + lm_gguf_free(ctx); + return NULL; + } for (uint64_t j = 0; j < kv->value.arr.n; ++j) { ok = ok && lm_gguf_fread_str(file, &((struct lm_gguf_str *) kv->value.arr.data)[j], &offset); } } break; case LM_GGUF_TYPE_ARRAY: - default: LM_GGML_ABORT("invalid type"); + default: + { + fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type); + ok = false; + } break; } } break; - default: LM_GGML_ABORT("invalid type"); + default: + { + fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type); + ok = false; + } break; } if (!ok) { break; } - - ctx->header.n_kv++; } if (!ok) { @@ -20976,7 +22324,13 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg // read the tensor infos if (ctx->header.n_tensors > 0) { - ctx->infos = LM_GGML_CALLOC(ctx->header.n_tensors, sizeof(struct lm_gguf_tensor_info)); + ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct lm_gguf_tensor_info)); + if (!ctx->infos) { + fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__); + fclose(file); + lm_gguf_free(ctx); + return NULL; + } for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { struct lm_gguf_tensor_info * info = &ctx->infos[i]; @@ -20997,8 +22351,7 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg ok = ok && lm_gguf_fread_el (file, &info->type, sizeof(info->type), &offset); ok = ok && lm_gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); - // TODO: return an error instead of crashing with LM_GGML_ASSERT - lm_gguf_tensor_info_sanitize(info); + ok = ok && lm_gguf_tensor_info_sanitize(info); // make sure there is no duplicated tensor names for (uint64_t j = 0; j < i && ok; ++j) { @@ -21550,6 +22903,7 @@ void lm_gguf_set_kv(struct lm_gguf_context * ctx, struct lm_gguf_context * src) void lm_gguf_add_tensor( struct lm_gguf_context * ctx, const struct lm_ggml_tensor * tensor) { + LM_GGML_ASSERT(tensor); if (lm_gguf_find_tensor(ctx, tensor->name) != -1) { LM_GGML_ABORT("duplicated tensor name"); } @@ -21877,6 +23231,14 @@ int lm_ggml_cpu_has_avx512_bf16(void) { #endif } +int lm_ggml_cpu_has_amx_int8(void) { +#if defined(__AMX_INT8__) + return 1; +#else + return 0; +#endif +} + int lm_ggml_cpu_has_fma(void) { #if defined(__FMA__) return 1; @@ -21886,16 +23248,16 @@ int lm_ggml_cpu_has_fma(void) { } int lm_ggml_cpu_has_neon(void) { -#if defined(__ARM_NEON) - return 1; +#if defined(__ARM_ARCH) + return lm_ggml_arm_arch_features.has_neon; #else return 0; #endif } int lm_ggml_cpu_has_sve(void) { -#if defined(__ARM_FEATURE_SVE) - return 1; +#if defined(__ARM_ARCH) + return lm_ggml_arm_arch_features.has_sve; #else return 0; #endif @@ -21909,6 +23271,14 @@ int lm_ggml_cpu_has_arm_fma(void) { #endif } +int lm_ggml_cpu_has_riscv_v(void) { +#if defined(__riscv_v_intrinsic) + return 1; +#else + return 0; +#endif +} + int lm_ggml_cpu_has_metal(void) { #if defined(LM_GGML_USE_METAL) return 1; @@ -22034,11 +23404,23 @@ int lm_ggml_cpu_has_vsx(void) { } int lm_ggml_cpu_has_matmul_int8(void) { -#if defined(__ARM_FEATURE_MATMUL_INT8) - return 1; +#if defined(__ARM_ARCH) + return lm_ggml_arm_arch_features.has_i8mm; +#else + return 0; +#endif +} + +int lm_ggml_cpu_get_sve_cnt(void) { +#if defined(__ARM_ARCH) + return lm_ggml_arm_arch_features.sve_cnt; #else return 0; #endif } +void lm_ggml_log_set(lm_ggml_log_callback log_callback, void * user_data) { + g_logger_state.log_callback = log_callback ? log_callback : lm_ggml_log_callback_default; + g_logger_state.log_callback_user_data = user_data; +} //////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/ggml.h b/cpp/ggml.h index 46342f5..e89b832 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -187,16 +187,6 @@ # define LM_GGML_API #endif -#ifdef LM_GGML_MULTIPLATFORM -# if defined(_WIN32) -# define LM_GGML_CALL -# else -# define LM_GGML_CALL __attribute__((__ms_abi__)) -# endif -#else -# define LM_GGML_CALL -#endif - // TODO: support for clang #ifdef __GNUC__ # define LM_GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) @@ -220,21 +210,24 @@ #include #define LM_GGML_FILE_MAGIC 0x67676d6c // "ggml" -#define LM_GGML_FILE_VERSION 1 +#define LM_GGML_FILE_VERSION 2 #define LM_GGML_QNT_VERSION 2 // bump this on quantization format changes #define LM_GGML_QNT_VERSION_FACTOR 1000 // do not change this #define LM_GGML_MAX_DIMS 4 #define LM_GGML_MAX_PARAMS 2048 -#define LM_GGML_MAX_CONTEXTS 64 #define LM_GGML_MAX_SRC 10 +#define LM_GGML_MAX_N_THREADS 512 +#define LM_GGML_MAX_OP_PARAMS 64 + #ifndef LM_GGML_MAX_NAME -#define LM_GGML_MAX_NAME 64 +# define LM_GGML_MAX_NAME 64 #endif -#define LM_GGML_MAX_OP_PARAMS 64 + #define LM_GGML_DEFAULT_N_THREADS 4 #define LM_GGML_DEFAULT_GRAPH_SIZE 2048 + #if UINTPTR_MAX == 0xFFFFFFFF #define LM_GGML_MEM_ALIGN 4 #else @@ -257,21 +250,21 @@ #define LM_GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) #ifndef NDEBUG -#define LM_GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0) +# define LM_GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0) #elif defined(__GNUC__) -#define LM_GGML_UNREACHABLE() __builtin_unreachable() +# define LM_GGML_UNREACHABLE() __builtin_unreachable() #elif defined(_MSC_VER) -#define LM_GGML_UNREACHABLE() __assume(0) +# define LM_GGML_UNREACHABLE() __assume(0) #else -#define LM_GGML_UNREACHABLE() ((void) 0) +# define LM_GGML_UNREACHABLE() ((void) 0) #endif #ifdef __cplusplus -#define LM_GGML_NORETURN [[noreturn]] +# define LM_GGML_NORETURN [[noreturn]] #elif defined(_MSC_VER) -#define LM_GGML_NORETURN __declspec(noreturn) +# define LM_GGML_NORETURN __declspec(noreturn) #else -#define LM_GGML_NORETURN _Noreturn +# define LM_GGML_NORETURN _Noreturn #endif #define LM_GGML_ABORT(...) lm_ggml_abort(__FILE__, __LINE__, __VA_ARGS__) @@ -336,7 +329,7 @@ extern "C" { }; // get lm_ggml_status name string - LM_GGML_API LM_GGML_CALL const char * lm_ggml_status_to_string(enum lm_ggml_status status); + LM_GGML_API const char * lm_ggml_status_to_string(enum lm_ggml_status status); // ieee 754-2008 half-precision float16 // todo: make this not an integral type @@ -356,6 +349,7 @@ extern "C" { struct lm_ggml_object; struct lm_ggml_context; + struct lm_ggml_cgraph; // NOTE: always add types at the end of the enum to keep backward compatibility enum lm_ggml_type { @@ -393,6 +387,8 @@ extern "C" { LM_GGML_TYPE_Q4_0_4_4 = 31, LM_GGML_TYPE_Q4_0_4_8 = 32, LM_GGML_TYPE_Q4_0_8_8 = 33, + LM_GGML_TYPE_TQ1_0 = 34, + LM_GGML_TYPE_TQ2_0 = 35, LM_GGML_TYPE_COUNT, }; @@ -453,10 +449,13 @@ extern "C" { LM_GGML_OP_SQR, LM_GGML_OP_SQRT, LM_GGML_OP_LOG, + LM_GGML_OP_SIN, + LM_GGML_OP_COS, LM_GGML_OP_SUM, LM_GGML_OP_SUM_ROWS, LM_GGML_OP_MEAN, LM_GGML_OP_ARGMAX, + LM_GGML_OP_COUNT_EQUAL, LM_GGML_OP_REPEAT, LM_GGML_OP_REPEAT_BACK, LM_GGML_OP_CONCAT, @@ -490,9 +489,11 @@ extern "C" { LM_GGML_OP_CLAMP, LM_GGML_OP_CONV_TRANSPOSE_1D, LM_GGML_OP_IM2COL, + LM_GGML_OP_IM2COL_BACK, LM_GGML_OP_CONV_TRANSPOSE_2D, LM_GGML_OP_POOL_1D, LM_GGML_OP_POOL_2D, + LM_GGML_OP_POOL_2D_BACK, LM_GGML_OP_UPSCALE, // nearest interpolate LM_GGML_OP_PAD, LM_GGML_OP_ARANGE, @@ -508,6 +509,7 @@ extern "C" { LM_GGML_OP_WIN_UNPART, LM_GGML_OP_GET_REL_POS, LM_GGML_OP_ADD_REL_POS, + LM_GGML_OP_RWKV_WKV, LM_GGML_OP_UNARY, @@ -524,6 +526,7 @@ extern "C" { LM_GGML_OP_CROSS_ENTROPY_LOSS, LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK, + LM_GGML_OP_OPT_STEP_ADAMW, LM_GGML_OP_COUNT, }; @@ -542,6 +545,7 @@ extern "C" { LM_GGML_UNARY_OP_SILU, LM_GGML_UNARY_OP_HARDSWISH, LM_GGML_UNARY_OP_HARDSIGMOID, + LM_GGML_UNARY_OP_EXP, LM_GGML_UNARY_OP_COUNT, }; @@ -553,35 +557,25 @@ extern "C" { }; enum lm_ggml_log_level { - LM_GGML_LOG_LEVEL_ERROR = 2, + LM_GGML_LOG_LEVEL_NONE = 0, + LM_GGML_LOG_LEVEL_DEBUG = 1, + LM_GGML_LOG_LEVEL_INFO = 2, LM_GGML_LOG_LEVEL_WARN = 3, - LM_GGML_LOG_LEVEL_INFO = 4, - LM_GGML_LOG_LEVEL_DEBUG = 5 + LM_GGML_LOG_LEVEL_ERROR = 4, + LM_GGML_LOG_LEVEL_CONT = 5, // continue previous log }; + // this tensor... enum lm_ggml_tensor_flag { - LM_GGML_TENSOR_FLAG_INPUT = 1, - LM_GGML_TENSOR_FLAG_OUTPUT = 2, - LM_GGML_TENSOR_FLAG_PARAM = 4, - }; - - // ggml object - struct lm_ggml_object { - size_t offs; - size_t size; - - struct lm_ggml_object * next; - - enum lm_ggml_object_type type; - - char padding[4]; + LM_GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph + LM_GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph + LM_GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters + LM_GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) }; - static const size_t LM_GGML_OBJECT_SIZE = sizeof(struct lm_ggml_object); - // n-dimensional tensor struct lm_ggml_tensor { - enum lm_ggml_type type; + enum lm_ggml_type type; LM_GGML_DEPRECATED(enum lm_ggml_backend_type backend, "use the buffer type to find the storage location of the tensor"); @@ -624,6 +618,29 @@ extern "C" { // If it returns true, the computation is aborted typedef bool (*lm_ggml_abort_callback)(void * data); + // Scheduling priorities + enum lm_ggml_sched_priority { + LM_GGML_SCHED_PRIO_NORMAL, + LM_GGML_SCHED_PRIO_MEDIUM, + LM_GGML_SCHED_PRIO_HIGH, + LM_GGML_SCHED_PRIO_REALTIME + }; + + // Threadpool params + // Use lm_ggml_threadpool_params_default() or lm_ggml_threadpool_params_init() to populate the defaults + struct lm_ggml_threadpool_params { + bool cpumask[LM_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings) + int n_threads; // number of threads + enum lm_ggml_sched_priority prio; // thread priority + uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) + bool strict_cpu; // strict cpu placement + bool paused; // start in paused state + }; + + struct lm_ggml_threadpool; // forward declaration, see ggml.c + + typedef struct lm_ggml_threadpool * lm_ggml_threadpool_t; + // the compute plan that needs to be prepared for lm_ggml_graph_compute() // since https://github.com/ggerganov/ggml/issues/287 struct lm_ggml_cplan { @@ -631,48 +648,13 @@ extern "C" { uint8_t * work_data; // work buffer, to be allocated by caller before calling to `lm_ggml_graph_compute()` int n_threads; + struct lm_ggml_threadpool * threadpool; // abort lm_ggml_graph_compute when true lm_ggml_abort_callback abort_callback; void * abort_callback_data; }; - enum lm_ggml_cgraph_eval_order { - LM_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0, - LM_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT, - LM_GGML_CGRAPH_EVAL_ORDER_COUNT - }; - - typedef uint32_t lm_ggml_bitset_t; - - struct lm_ggml_hash_set { - size_t size; - lm_ggml_bitset_t * used; - struct lm_ggml_tensor ** keys; - }; - - // computation graph - struct lm_ggml_cgraph { - int size; - int n_nodes; - int n_leafs; - - struct lm_ggml_tensor ** nodes; - struct lm_ggml_tensor ** grads; - struct lm_ggml_tensor ** leafs; - - struct lm_ggml_hash_set visited_hash_set; - - enum lm_ggml_cgraph_eval_order order; - }; - - // scratch buffer - struct lm_ggml_scratch { - size_t offs; - size_t size; - void * data; - }; - struct lm_ggml_init_params { // memory pool size_t mem_size; // bytes @@ -717,46 +699,46 @@ extern "C" { LM_GGML_API void lm_ggml_print_object (const struct lm_ggml_object * obj); LM_GGML_API void lm_ggml_print_objects(const struct lm_ggml_context * ctx); - LM_GGML_API LM_GGML_CALL int64_t lm_ggml_nelements (const struct lm_ggml_tensor * tensor); - LM_GGML_API LM_GGML_CALL int64_t lm_ggml_nrows (const struct lm_ggml_tensor * tensor); - LM_GGML_API LM_GGML_CALL size_t lm_ggml_nbytes (const struct lm_ggml_tensor * tensor); - LM_GGML_API size_t lm_ggml_nbytes_pad (const struct lm_ggml_tensor * tensor); // same as lm_ggml_nbytes() but padded to LM_GGML_MEM_ALIGN + LM_GGML_API int64_t lm_ggml_nelements (const struct lm_ggml_tensor * tensor); + LM_GGML_API int64_t lm_ggml_nrows (const struct lm_ggml_tensor * tensor); + LM_GGML_API size_t lm_ggml_nbytes (const struct lm_ggml_tensor * tensor); + LM_GGML_API size_t lm_ggml_nbytes_pad(const struct lm_ggml_tensor * tensor); // same as lm_ggml_nbytes() but padded to LM_GGML_MEM_ALIGN - LM_GGML_API LM_GGML_CALL int64_t lm_ggml_blck_size(enum lm_ggml_type type); - LM_GGML_API LM_GGML_CALL size_t lm_ggml_type_size(enum lm_ggml_type type); // size in bytes for all elements in a block - LM_GGML_API LM_GGML_CALL size_t lm_ggml_row_size (enum lm_ggml_type type, int64_t ne); // size in bytes for all elements in a row + LM_GGML_API int64_t lm_ggml_blck_size(enum lm_ggml_type type); + LM_GGML_API size_t lm_ggml_type_size(enum lm_ggml_type type); // size in bytes for all elements in a block + LM_GGML_API size_t lm_ggml_row_size (enum lm_ggml_type type, int64_t ne); // size in bytes for all elements in a row LM_GGML_DEPRECATED( LM_GGML_API double lm_ggml_type_sizef(enum lm_ggml_type type), // lm_ggml_type_size()/lm_ggml_blck_size() as float "use lm_ggml_row_size() instead"); - LM_GGML_API LM_GGML_CALL const char * lm_ggml_type_name(enum lm_ggml_type type); - LM_GGML_API LM_GGML_CALL const char * lm_ggml_op_name (enum lm_ggml_op op); - LM_GGML_API const char * lm_ggml_op_symbol(enum lm_ggml_op op); + LM_GGML_API const char * lm_ggml_type_name(enum lm_ggml_type type); + LM_GGML_API const char * lm_ggml_op_name (enum lm_ggml_op op); + LM_GGML_API const char * lm_ggml_op_symbol(enum lm_ggml_op op); - LM_GGML_API const char * lm_ggml_unary_op_name(enum lm_ggml_unary_op op); - LM_GGML_API LM_GGML_CALL const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t); // unary or op name + LM_GGML_API const char * lm_ggml_unary_op_name(enum lm_ggml_unary_op op); + LM_GGML_API const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t); // unary or op name - LM_GGML_API LM_GGML_CALL size_t lm_ggml_element_size(const struct lm_ggml_tensor * tensor); + LM_GGML_API size_t lm_ggml_element_size(const struct lm_ggml_tensor * tensor); - LM_GGML_API LM_GGML_CALL bool lm_ggml_is_quantized(enum lm_ggml_type type); + LM_GGML_API bool lm_ggml_is_quantized(enum lm_ggml_type type); // TODO: temporary until model loading of ggml examples is refactored LM_GGML_API enum lm_ggml_type lm_ggml_ftype_to_lm_ggml_type(enum lm_ggml_ftype ftype); - LM_GGML_API LM_GGML_CALL bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor); - LM_GGML_API LM_GGML_CALL bool lm_ggml_is_permuted (const struct lm_ggml_tensor * tensor); - LM_GGML_API LM_GGML_CALL bool lm_ggml_is_empty (const struct lm_ggml_tensor * tensor); - LM_GGML_API bool lm_ggml_is_scalar (const struct lm_ggml_tensor * tensor); - LM_GGML_API bool lm_ggml_is_vector (const struct lm_ggml_tensor * tensor); - LM_GGML_API bool lm_ggml_is_matrix (const struct lm_ggml_tensor * tensor); - LM_GGML_API bool lm_ggml_is_3d (const struct lm_ggml_tensor * tensor); - LM_GGML_API int lm_ggml_n_dims (const struct lm_ggml_tensor * tensor); // returns 1 for scalars + LM_GGML_API bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_permuted (const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_empty (const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_scalar (const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_vector (const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_matrix (const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_3d (const struct lm_ggml_tensor * tensor); + LM_GGML_API int lm_ggml_n_dims (const struct lm_ggml_tensor * tensor); // returns 1 for scalars - LM_GGML_API LM_GGML_CALL bool lm_ggml_is_contiguous (const struct lm_ggml_tensor * tensor); - LM_GGML_API LM_GGML_CALL bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor); // same as lm_ggml_is_contiguous() - LM_GGML_API LM_GGML_CALL bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 1 - LM_GGML_API LM_GGML_CALL bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 2 + LM_GGML_API bool lm_ggml_is_contiguous (const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor); // same as lm_ggml_is_contiguous() + LM_GGML_API bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 1 + LM_GGML_API bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 2 LM_GGML_API bool lm_ggml_are_same_shape (const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1); LM_GGML_API bool lm_ggml_are_same_stride(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1); @@ -770,12 +752,12 @@ extern "C" { // main - LM_GGML_API struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params); - LM_GGML_API void lm_ggml_free(struct lm_ggml_context * ctx); + LM_GGML_API struct lm_ggml_context * lm_ggml_init (struct lm_ggml_init_params params); + LM_GGML_API void lm_ggml_reset(struct lm_ggml_context * ctx); + LM_GGML_API void lm_ggml_free (struct lm_ggml_context * ctx); LM_GGML_API size_t lm_ggml_used_mem(const struct lm_ggml_context * ctx); - LM_GGML_API size_t lm_ggml_set_scratch (struct lm_ggml_context * ctx, struct lm_ggml_scratch scratch); LM_GGML_API bool lm_ggml_get_no_alloc(struct lm_ggml_context * ctx); LM_GGML_API void lm_ggml_set_no_alloc(struct lm_ggml_context * ctx, bool no_alloc); @@ -848,7 +830,7 @@ extern "C" { LM_GGML_API void * lm_ggml_get_data (const struct lm_ggml_tensor * tensor); LM_GGML_API float * lm_ggml_get_data_f32(const struct lm_ggml_tensor * tensor); - LM_GGML_API LM_GGML_CALL enum lm_ggml_unary_op lm_ggml_get_unary_op(const struct lm_ggml_tensor * tensor); + LM_GGML_API enum lm_ggml_unary_op lm_ggml_get_unary_op(const struct lm_ggml_tensor * tensor); LM_GGML_API const char * lm_ggml_get_name (const struct lm_ggml_tensor * tensor); LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_name ( struct lm_ggml_tensor * tensor, const char * name); @@ -969,6 +951,22 @@ extern "C" { struct lm_ggml_context * ctx, struct lm_ggml_tensor * a); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_sin( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a); + + LM_GGML_API struct lm_ggml_tensor * lm_ggml_sin_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a); + + LM_GGML_API struct lm_ggml_tensor * lm_ggml_cos( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a); + + LM_GGML_API struct lm_ggml_tensor * lm_ggml_cos_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a); + // return scalar LM_GGML_API struct lm_ggml_tensor * lm_ggml_sum( struct lm_ggml_context * ctx, @@ -989,6 +987,12 @@ extern "C" { struct lm_ggml_context * ctx, struct lm_ggml_tensor * a); + // count number of equal elements in a and b + LM_GGML_API struct lm_ggml_tensor * lm_ggml_count_equal( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b); + // if a is the same shape as b, and a is not parameter, return a // otherwise, return a new tensor: repeat(a) to fit in b LM_GGML_API struct lm_ggml_tensor * lm_ggml_repeat( @@ -1119,6 +1123,14 @@ extern "C" { struct lm_ggml_context * ctx, struct lm_ggml_tensor * a); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_exp( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a); + + LM_GGML_API struct lm_ggml_tensor * lm_ggml_exp_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a); + // normalize along rows LM_GGML_API struct lm_ggml_tensor * lm_ggml_norm( struct lm_ggml_context * ctx, @@ -1214,7 +1226,7 @@ extern "C" { size_t nb1, size_t nb2, size_t nb3, - size_t offset); + size_t offset); // in bytes // b -> view(a,offset,nb1,nb2,3), return view(a) LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_inplace( @@ -1224,19 +1236,19 @@ extern "C" { size_t nb1, size_t nb2, size_t nb3, - size_t offset); + size_t offset); // in bytes LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_1d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - size_t offset); + size_t offset); // in bytes LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_1d_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - size_t offset); + size_t offset); // in bytes // b -> view(a,offset,nb1,nb2,3), return modified a LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_2d( @@ -1244,7 +1256,7 @@ extern "C" { struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, size_t nb1, - size_t offset); + size_t offset); // in bytes // b -> view(a,offset,nb1,nb2,3), return view(a) LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_2d_inplace( @@ -1252,7 +1264,7 @@ extern "C" { struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, size_t nb1, - size_t offset); + size_t offset); // in bytes // a -> b, return view(b) LM_GGML_API struct lm_ggml_tensor * lm_ggml_cpy( @@ -1387,14 +1399,14 @@ extern "C" { // supports 3D: a->ne[2] == b->ne[1] LM_GGML_API struct lm_ggml_tensor * lm_ggml_get_rows( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b); + struct lm_ggml_tensor * a, // data + struct lm_ggml_tensor * b); // row indices LM_GGML_API struct lm_ggml_tensor * lm_ggml_get_rows_back( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c); + struct lm_ggml_tensor * a, // gradients of lm_ggml_get_rows result + struct lm_ggml_tensor * b, // row indices + struct lm_ggml_tensor * c); // data for lm_ggml_get_rows, only used for its shape LM_GGML_API struct lm_ggml_tensor * lm_ggml_diag( struct lm_ggml_context * ctx, @@ -1538,16 +1550,16 @@ extern "C" { "use lm_ggml_rope_ext_inplace instead"); // compute correction dims for YaRN RoPE scaling - LM_GGML_CALL void lm_ggml_rope_yarn_corr_dims( + void lm_ggml_rope_yarn_corr_dims( int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]); // rotary position embedding backward, i.e compute dx from dy // a - dy LM_GGML_API struct lm_ggml_tensor * lm_ggml_rope_back( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, + struct lm_ggml_tensor * a, // gradients of lm_ggml_rope result + struct lm_ggml_tensor * b, // positions + struct lm_ggml_tensor * c, // freq factors int n_dims, int mode, int n_ctx_orig, @@ -1566,34 +1578,49 @@ extern "C" { float min, float max); + // im2col + // converts data into a format that effectively results in a convolution when combined with matrix multiplication LM_GGML_API struct lm_ggml_tensor * lm_ggml_im2col( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D, - enum lm_ggml_type dst_type); + struct lm_ggml_tensor * a, // convolution kernel + struct lm_ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1, // dilation dimension 1 + bool is_2D, + enum lm_ggml_type dst_type); + + LM_GGML_API struct lm_ggml_tensor * lm_ggml_im2col_back( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, // convolution kernel + struct lm_ggml_tensor * b, // gradient of im2col output + int64_t * ne, // shape of im2col input + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1, // dilation dimension 1 + bool is_2D); LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1); + struct lm_ggml_tensor * a, // convolution kernel + struct lm_ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_1d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + struct lm_ggml_tensor * a, // convolution kernel + struct lm_ggml_tensor * b, // data int s0, // stride int p0, // padding int d0); // dilation @@ -1602,29 +1629,29 @@ extern "C" { // alias for lm_ggml_conv_1d(a, b, s, a->ne[0]/2, d) LM_GGML_API struct lm_ggml_tensor* lm_ggml_conv_1d_ph( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s, - int d); + struct lm_ggml_tensor * a, // convolution kernel + struct lm_ggml_tensor * b, // data + int s, // stride + int d); // dilation LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int p0, - int d0); + struct lm_ggml_tensor * a, // convolution kernel + struct lm_ggml_tensor * b, // data + int s0, // stride + int p0, // padding + int d0); // dilation LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_2d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1); + struct lm_ggml_tensor * a, // convolution kernel + struct lm_ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 // kernel size is a->ne[0] x a->ne[1] @@ -1686,6 +1713,18 @@ extern "C" { float p0, float p1); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_pool_2d_back( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * af, // "a"/input used in forward pass + enum lm_ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1); + // nearest interpolate // multiplies ne0 and ne1 by scale factor // used in stable-diffusion @@ -1840,6 +1879,15 @@ extern "C" { struct lm_ggml_tensor * pw, struct lm_ggml_tensor * ph); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_rwkv_wkv( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * k, + struct lm_ggml_tensor * v, + struct lm_ggml_tensor * r, + struct lm_ggml_tensor * tf, + struct lm_ggml_tensor * td, + struct lm_ggml_tensor * state); + // custom operators typedef void (*lm_ggml_unary_op_f32_t) (const int, float *, const float *); @@ -1923,7 +1971,8 @@ extern "C" { typedef void (*lm_ggml_custom2_op_t)(struct lm_ggml_tensor * dst , const struct lm_ggml_tensor * a, const struct lm_ggml_tensor * b, int ith, int nth, void * userdata); typedef void (*lm_ggml_custom3_op_t)(struct lm_ggml_tensor * dst , const struct lm_ggml_tensor * a, const struct lm_ggml_tensor * b, const struct lm_ggml_tensor * c, int ith, int nth, void * userdata); - #define LM_GGML_N_TASKS_MAX -1 +#define LM_GGML_N_TASKS_MAX (-1) + // n_tasks == LM_GGML_N_TASKS_MAX means to use max number of tasks LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom1( struct lm_ggml_context * ctx, @@ -1976,44 +2025,84 @@ extern "C" { // loss function LM_GGML_API struct lm_ggml_tensor * lm_ggml_cross_entropy_loss( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b); + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, // logits + struct lm_ggml_tensor * b); // labels LM_GGML_API struct lm_ggml_tensor * lm_ggml_cross_entropy_loss_back( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c); + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, // logits + struct lm_ggml_tensor * b, // labels + struct lm_ggml_tensor * c); // gradients of cross_entropy_loss result + + // AdamW optimizer step + // Paper: https://arxiv.org/pdf/1711.05101v3.pdf + // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html + LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_step_adamw( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * grad, + float alpha, + float beta1, + float beta2, + float eps, + float wd); // weight decay // // automatic differentiation // - LM_GGML_API void lm_ggml_set_param( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * tensor); - + LM_GGML_API void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor); + LM_GGML_API void lm_ggml_set_loss(struct lm_ggml_tensor * tensor); LM_GGML_API void lm_ggml_build_forward_expand (struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor); - LM_GGML_API void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool keep); + LM_GGML_API void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate); + + LM_GGML_API void lm_ggml_build_opt_adamw( + struct lm_ggml_context * ctx, + struct lm_ggml_cgraph * gf, + struct lm_ggml_cgraph * gb, + float alpha, + float beta1, + float beta2, + float eps, + float wd); // weight decay // graph allocation in a context - LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph (struct lm_ggml_context * ctx); // size = LM_GGML_DEFAULT_GRAPH_SIZE, grads = false - LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph_custom (struct lm_ggml_context * ctx, size_t size, bool grads); - LM_GGML_API struct lm_ggml_cgraph * lm_ggml_graph_dup (struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph); - LM_GGML_API struct lm_ggml_cgraph lm_ggml_graph_view (struct lm_ggml_cgraph * cgraph, int i0, int i1); - LM_GGML_API void lm_ggml_graph_cpy (struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst); - LM_GGML_API void lm_ggml_graph_reset (struct lm_ggml_cgraph * cgraph); // zero grads - LM_GGML_API void lm_ggml_graph_clear (struct lm_ggml_cgraph * cgraph); + LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph (struct lm_ggml_context * ctx); // size = LM_GGML_DEFAULT_GRAPH_SIZE, grads = false + LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, size_t size, bool grads); + LM_GGML_API struct lm_ggml_cgraph * lm_ggml_graph_dup (struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph); + LM_GGML_API void lm_ggml_graph_cpy (struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst); + LM_GGML_API void lm_ggml_graph_reset (struct lm_ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1 + LM_GGML_API void lm_ggml_graph_clear (struct lm_ggml_cgraph * cgraph); + + LM_GGML_API int lm_ggml_graph_size (struct lm_ggml_cgraph * cgraph); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_graph_node (struct lm_ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i] + LM_GGML_API struct lm_ggml_tensor ** lm_ggml_graph_nodes (struct lm_ggml_cgraph * cgraph); + LM_GGML_API int lm_ggml_graph_n_nodes(struct lm_ggml_cgraph * cgraph); + + LM_GGML_API void lm_ggml_graph_add_node(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor); LM_GGML_API size_t lm_ggml_graph_overhead(void); LM_GGML_API size_t lm_ggml_graph_overhead_custom(size_t size, bool grads); + LM_GGML_API struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads); + LM_GGML_API void lm_ggml_threadpool_params_init (struct lm_ggml_threadpool_params * p, int n_threads); + LM_GGML_API bool lm_ggml_threadpool_params_match (const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1); + LM_GGML_API struct lm_ggml_threadpool * lm_ggml_threadpool_new (struct lm_ggml_threadpool_params * params); + LM_GGML_API void lm_ggml_threadpool_free (struct lm_ggml_threadpool * threadpool); + LM_GGML_API int lm_ggml_threadpool_get_n_threads(struct lm_ggml_threadpool * threadpool); + LM_GGML_API void lm_ggml_threadpool_pause (struct lm_ggml_threadpool * threadpool); + LM_GGML_API void lm_ggml_threadpool_resume (struct lm_ggml_threadpool * threadpool); + // lm_ggml_graph_plan() has to be called before lm_ggml_graph_compute() // when plan.work_size > 0, caller must allocate memory for plan.work_data - LM_GGML_API struct lm_ggml_cplan lm_ggml_graph_plan (const struct lm_ggml_cgraph * cgraph, int n_threads /*= LM_GGML_DEFAULT_N_THREADS*/); - LM_GGML_API enum lm_ggml_status lm_ggml_graph_compute( struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan); + LM_GGML_API struct lm_ggml_cplan lm_ggml_graph_plan( + const struct lm_ggml_cgraph * cgraph, + int n_threads, /* = LM_GGML_DEFAULT_N_THREADS */ + struct lm_ggml_threadpool * threadpool /* = NULL */ ); + LM_GGML_API enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan); + // same as lm_ggml_graph_compute() but the work data is allocated as a part of the context // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data LM_GGML_API enum lm_ggml_status lm_ggml_graph_compute_with_ctx(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int n_threads); @@ -2077,6 +2166,10 @@ extern "C" { typedef void (*lm_ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel); typedef void (*lm_ggml_log_callback)(enum lm_ggml_log_level level, const char * text, void * user_data); + // Set callback for all future logging events. + // If this is not called, or NULL is supplied, everything is output on stderr. + LM_GGML_API void lm_ggml_log_set(lm_ggml_log_callback log_callback, void * user_data); + // optimization parameters // // see ggml.c (lm_ggml_opt_default_params) for default values @@ -2387,6 +2480,7 @@ extern "C" { LM_GGML_API int lm_ggml_cpu_has_avx512_vbmi(void); LM_GGML_API int lm_ggml_cpu_has_avx512_vnni(void); LM_GGML_API int lm_ggml_cpu_has_avx512_bf16(void); + LM_GGML_API int lm_ggml_cpu_has_amx_int8 (void); LM_GGML_API int lm_ggml_cpu_has_fma (void); LM_GGML_API int lm_ggml_cpu_has_neon (void); LM_GGML_API int lm_ggml_cpu_has_sve (void); @@ -2402,6 +2496,7 @@ extern "C" { LM_GGML_API int lm_ggml_cpu_has_gpublas (void); LM_GGML_API int lm_ggml_cpu_has_sse3 (void); LM_GGML_API int lm_ggml_cpu_has_ssse3 (void); + LM_GGML_API int lm_ggml_cpu_has_riscv_v (void); LM_GGML_API int lm_ggml_cpu_has_sycl (void); LM_GGML_API int lm_ggml_cpu_has_rpc (void); LM_GGML_API int lm_ggml_cpu_has_vsx (void); @@ -2409,6 +2504,9 @@ extern "C" { LM_GGML_API int lm_ggml_cpu_has_cann (void); LM_GGML_API int lm_ggml_cpu_has_llamafile (void); + // get the sve vector length in bytes + LM_GGML_API int lm_ggml_cpu_get_sve_cnt(void); + // // Internal types and functions exposed for tests and benchmarks // @@ -2430,7 +2528,7 @@ extern "C" { typedef void (*lm_ggml_gemm_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, const void * LM_GGML_RESTRICT y, int nr, int nc); - typedef struct { + struct lm_ggml_type_traits { const char * type_name; int64_t blck_size; int64_t blck_size_interleave; // interleave elements in blocks @@ -2446,9 +2544,9 @@ extern "C" { int64_t ncols; // number of columns to process simultaneously lm_ggml_gemv_t gemv; lm_ggml_gemm_t gemm; - } lm_ggml_type_traits_t; + }; - LM_GGML_API lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type); + LM_GGML_API const struct lm_ggml_type_traits * lm_ggml_get_type_traits(enum lm_ggml_type type); #ifdef __cplusplus } diff --git a/cpp/grammar-parser.cpp b/cpp/grammar-parser.cpp deleted file mode 100644 index 438452e..0000000 --- a/cpp/grammar-parser.cpp +++ /dev/null @@ -1,539 +0,0 @@ -#include "grammar-parser.h" -#include -#include -#include -#include -#include -#include - -namespace grammar_parser { - // NOTE: assumes valid utf8 (but checks for overrun) - // copied from llama.cpp - static std::pair decode_utf8(const char * src) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t first_byte = static_cast(*src); - uint8_t highbits = first_byte >> 4; - int len = lookup[highbits]; - uint8_t mask = (1 << (8 - len)) - 1; - uint32_t value = first_byte & mask; - const char * end = src + len; // may overrun! - const char * pos = src + 1; - for ( ; pos < end && *pos; pos++) { - value = (value << 6) + (static_cast(*pos) & 0x3F); - } - return std::make_pair(value, pos); - } - - static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { - uint32_t next_id = static_cast(state.symbol_ids.size()); - auto result = state.symbol_ids.emplace(std::string(src, len), next_id); - return result.first->second; - } - - static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { - uint32_t next_id = static_cast(state.symbol_ids.size()); - state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; - return next_id; - } - - static void add_rule( - parse_state & state, - uint32_t rule_id, - const std::vector & rule) { - if (state.rules.size() <= rule_id) { - state.rules.resize(rule_id + 1); - } - state.rules[rule_id] = rule; - } - - static bool is_digit_char(char c) { - return '0' <= c && c <= '9'; - } - - static bool is_word_char(char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); - } - - static std::pair parse_hex(const char * src, int size) { - const char * pos = src; - const char * end = src + size; - uint32_t value = 0; - for ( ; pos < end && *pos; pos++) { - value <<= 4; - char c = *pos; - if ('a' <= c && c <= 'f') { - value += c - 'a' + 10; - } else if ('A' <= c && c <= 'F') { - value += c - 'A' + 10; - } else if ('0' <= c && c <= '9') { - value += c - '0'; - } else { - break; - } - } - if (pos != end) { - throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); - } - return std::make_pair(value, pos); - } - - static const char * parse_space(const char * src, bool newline_ok) { - const char * pos = src; - while (*pos == ' ' || *pos == '\t' || *pos == '#' || - (newline_ok && (*pos == '\r' || *pos == '\n'))) { - if (*pos == '#') { - while (*pos && *pos != '\r' && *pos != '\n') { - pos++; - } - } else { - pos++; - } - } - return pos; - } - - static const char * parse_name(const char * src) { - const char * pos = src; - while (is_word_char(*pos)) { - pos++; - } - if (pos == src) { - throw std::runtime_error(std::string("expecting name at ") + src); - } - return pos; - } - - static const char * parse_int(const char * src) { - const char * pos = src; - while (is_digit_char(*pos)) { - pos++; - } - if (pos == src) { - throw std::runtime_error(std::string("expecting integer at ") + src); - } - return pos; - } - - static std::pair parse_char(const char * src) { - if (*src == '\\') { - switch (src[1]) { - case 'x': return parse_hex(src + 2, 2); - case 'u': return parse_hex(src + 2, 4); - case 'U': return parse_hex(src + 2, 8); - case 't': return std::make_pair('\t', src + 2); - case 'r': return std::make_pair('\r', src + 2); - case 'n': return std::make_pair('\n', src + 2); - case '\\': - case '"': - case '[': - case ']': - return std::make_pair(src[1], src + 2); - default: - throw std::runtime_error(std::string("unknown escape at ") + src); - } - } else if (*src) { - return decode_utf8(src); - } - throw std::runtime_error("unexpected end of input"); - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested); - - static const char * parse_sequence( - parse_state & state, - const char * src, - const std::string & rule_name, - std::vector & out_elements, - bool is_nested) { - size_t last_sym_start = out_elements.size(); - const char * pos = src; - - auto handle_repetitions = [&](int min_times, int max_times) { - - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); - } - - // apply transformation to previous symbol (last_sym_start to end) according to - // the following rewrite rules: - // S{m,n} --> S S S (m times) S'(n-m) - // S'(x) ::= S S'(x-1) | - // (... n-m definitions of these S' rules ...) - // S'(1) ::= S | - // S{m,} --> S S S (m times) S' - // S' ::= S S' | - // S* --> S{0,} - // --> S' ::= S S' | - // S+ --> S{1,} - // --> S S' - // S' ::= S S' | - // S? --> S{0,1} - // --> S' - // S' ::= S | - - std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); - if (min_times == 0) { - out_elements.resize(last_sym_start); - } else { - // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { - out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); - } - } - - uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; - - std::vector rec_rule(previous_elements); - for (int i = 0; i < n_opt; i++) { - rec_rule.resize(previous_elements.size()); - uint32_t rec_rule_id = generate_symbol_id(state, rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); - } - rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - rec_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rec_rule_id, rec_rule); - last_rec_rule_id = rec_rule_id; - } - if (n_opt > 0) { - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); - } - }; - - while (*pos) { - if (*pos == '"') { // literal string - pos++; - last_sym_start = out_elements.size(); - while (*pos != '"') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '[') { // char range(s) - pos++; - enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - if (*pos == '^') { - pos++; - start_type = LLAMA_GRETYPE_CHAR_NOT; - } - last_sym_start = out_elements.size(); - while (*pos != ']') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - enum llama_gretype type = last_sym_start < out_elements.size() - ? LLAMA_GRETYPE_CHAR_ALT - : start_type; - - out_elements.push_back({type, char_pair.first}); - if (pos[0] == '-' && pos[1] != ']') { - if (!pos[1]) { - throw std::runtime_error("unexpected end of input"); - } - auto endchar_pair = parse_char(pos + 1); - pos = endchar_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); - } - } - pos = parse_space(pos + 1, is_nested); - } else if (is_word_char(*pos)) { // rule reference - const char * name_end = parse_name(pos); - uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); - pos = parse_space(name_end, is_nested); - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - } else if (*pos == '(') { // grouping - // parse nested alternates into synthesized rule - pos = parse_space(pos + 1, true); - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); - last_sym_start = out_elements.size(); - // output reference to synthesized rule - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - if (*pos != ')') { - throw std::runtime_error(std::string("expecting ')' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '.') { // any char - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, -1); - } else if (*pos == '+') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(1, -1); - } else if (*pos == '?') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, 1); - } else if (*pos == '{') { - pos = parse_space(pos + 1, is_nested); - - if (!is_digit_char(*pos)) { - throw std::runtime_error(std::string("expecting an int at ") + pos); - } - const char * int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - - int max_times = -1; - - if (*pos == '}') { - max_times = min_times; - pos = parse_space(pos + 1, is_nested); - } else if (*pos == ',') { - pos = parse_space(pos + 1, is_nested); - - if (is_digit_char(*pos)) { - const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - } - - if (*pos != '}') { - throw std::runtime_error(std::string("expecting '}' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else { - throw std::runtime_error(std::string("expecting ',' at ") + pos); - } - handle_repetitions(min_times, max_times); - } else { - break; - } - } - return pos; - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested) { - std::vector rule; - const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); - while (*pos == '|') { - rule.push_back({LLAMA_GRETYPE_ALT, 0}); - pos = parse_space(pos + 1, true); - pos = parse_sequence(state, pos, rule_name, rule, is_nested); - } - rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rule_id, rule); - return pos; - } - - static const char * parse_rule(parse_state & state, const char * src) { - const char * name_end = parse_name(src); - const char * pos = parse_space(name_end, false); - size_t name_len = name_end - src; - uint32_t rule_id = get_symbol_id(state, src, name_len); - const std::string name(src, name_len); - - if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { - throw std::runtime_error(std::string("expecting ::= at ") + pos); - } - pos = parse_space(pos + 3, true); - - pos = parse_alternates(state, pos, name, rule_id, false); - - if (*pos == '\r') { - pos += pos[1] == '\n' ? 2 : 1; - } else if (*pos == '\n') { - pos++; - } else if (*pos) { - throw std::runtime_error(std::string("expecting newline or end at ") + pos); - } - return parse_space(pos, true); - } - - parse_state parse(const char * src) { - try { - parse_state state; - const char * pos = parse_space(src, true); - while (*pos) { - pos = parse_rule(state, pos); - } - // Validate the state to ensure that all rules are defined - for (const auto & rule : state.rules) { - if (rule.empty()) { - throw std::runtime_error("Undefined rule"); - } - for (const auto & elem : rule) { - if (elem.type == LLAMA_GRETYPE_RULE_REF) { - // Ensure that the rule at that location exists - if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) { - // Get the name of the rule that is missing - for (const auto & kv : state.symbol_ids) { - if (kv.second == elem.value) { - throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); - } - } - } - } - } - } - return state; - } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); - return parse_state(); - } - } - - static void print_grammar_char(FILE * file, uint32_t c) { - if (0x20 <= c && c <= 0x7f) { - fprintf(file, "%c", static_cast(c)); - } else { - // cop out of encoding UTF-8 - fprintf(file, "", c); - } - } - - static bool is_char_element(llama_grammar_element elem) { - switch (elem.type) { - case LLAMA_GRETYPE_CHAR: return true; - case LLAMA_GRETYPE_CHAR_NOT: return true; - case LLAMA_GRETYPE_CHAR_ALT: return true; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; - case LLAMA_GRETYPE_CHAR_ANY: return true; - default: return false; - } - } - - static void print_rule_binary(FILE * file, const std::vector & rule) { - for (auto elem : rule) { - switch (elem.type) { - case LLAMA_GRETYPE_END: fprintf(file, "END"); break; - case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; - case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; - case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; - case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; - case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; - case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; - } - switch (elem.type) { - case LLAMA_GRETYPE_END: - case LLAMA_GRETYPE_ALT: - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "(%u) ", elem.value); - break; - case LLAMA_GRETYPE_CHAR: - case LLAMA_GRETYPE_CHAR_NOT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "(\""); - print_grammar_char(file, elem.value); - fprintf(file, "\") "); - break; - } - } - fprintf(file, "\n"); - } - - static void print_rule( - FILE * file, - uint32_t rule_id, - const std::vector & rule, - const std::map & symbol_id_names) { - if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { - throw std::runtime_error( - "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); - } - fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - for (size_t i = 0, end = rule.size() - 1; i < end; i++) { - llama_grammar_element elem = rule[i]; - switch (elem.type) { - case LLAMA_GRETYPE_END: - throw std::runtime_error( - "unexpected end of rule: " + std::to_string(rule_id) + "," + - std::to_string(i)); - case LLAMA_GRETYPE_ALT: - fprintf(file, "| "); - break; - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); - break; - case LLAMA_GRETYPE_CHAR: - fprintf(file, "["); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_NOT: - fprintf(file, "[^"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - fprintf(file, "-"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ALT: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "."); - break; - } - if (is_char_element(elem)) { - switch (rule[i + 1].type) { - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ANY: - break; - default: - fprintf(file, "] "); - } - } - } - fprintf(file, "\n"); - } - - void print_grammar(FILE * file, const parse_state & state) { - try { - std::map symbol_id_names; - for (const auto & kv : state.symbol_ids) { - symbol_id_names[kv.second] = kv.first; - } - for (size_t i = 0, end = state.rules.size(); i < end; i++) { - // fprintf(file, "%zu: ", i); - // print_rule_binary(file, state.rules[i]); - print_rule(file, uint32_t(i), state.rules[i], symbol_id_names); - // fprintf(file, "\n"); - } - } catch (const std::exception & err) { - fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); - } - } - - std::vector parse_state::c_rules() { - std::vector ret; - ret.reserve(rules.size()); - for (const auto & rule : rules) { - ret.push_back(rule.data()); - } - return ret; - } -} diff --git a/cpp/grammar-parser.h b/cpp/grammar-parser.h deleted file mode 100644 index 9037d72..0000000 --- a/cpp/grammar-parser.h +++ /dev/null @@ -1,29 +0,0 @@ -// Implements a parser for an extended Backus-Naur form (BNF), producing the -// binary context-free grammar format specified by llama.h. Supports character -// ranges, grouping, and repetition operators. As an example, a grammar for -// arithmetic might look like: -// -// root ::= expr -// expr ::= term ([-+*/] term)* -// term ::= num | "(" space expr ")" space -// num ::= [0-9]+ space -// space ::= [ \t\n]* - -#pragma once -#include "llama.h" -#include -#include -#include -#include - -namespace grammar_parser { - struct parse_state { - std::map symbol_ids; - std::vector> rules; - - std::vector c_rules(); - }; - - parse_state parse(const char * src); - void print_grammar(FILE * file, const parse_state & state); -} diff --git a/cpp/json-schema-to-grammar.cpp b/cpp/json-schema-to-grammar.cpp index 881eb49..dadc18c 100644 --- a/cpp/json-schema-to-grammar.cpp +++ b/cpp/json-schema-to-grammar.cpp @@ -611,7 +611,7 @@ class SchemaConverter { } return join_seq(); }; - return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space"); + return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space"); } /* diff --git a/cpp/llama-grammar.cpp b/cpp/llama-grammar.cpp index bb38bd5..6cc2b38 100644 --- a/cpp/llama-grammar.cpp +++ b/cpp/llama-grammar.cpp @@ -3,11 +3,31 @@ #include "llama-vocab.h" #include "llama-sampling.h" +#include #include +#include -// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as -// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. -std::pair, llama_partial_utf8> decode_utf8( +// +// helpers +// + +// NOTE: assumes valid utf8 (but checks for overrun) +static std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); +} + +static std::pair, llama_partial_utf8> decode_utf8( const std::string & src, llama_partial_utf8 partial_start) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; @@ -40,7 +60,7 @@ std::pair, llama_partial_utf8> decode_utf8( while (*pos != 0) { uint8_t first_byte = static_cast(*pos); uint8_t highbits = first_byte >> 4; - n_remain = lookup[highbits] - 1; + n_remain = lookup[highbits] - 1; if (n_remain < 0) { // invalid sequence, abort @@ -50,7 +70,7 @@ std::pair, llama_partial_utf8> decode_utf8( } uint8_t mask = (1 << (7 - n_remain)) - 1; - value = first_byte & mask; + value = first_byte & mask; ++pos; while (*pos != 0 && n_remain > 0) { @@ -67,12 +87,510 @@ std::pair, llama_partial_utf8> decode_utf8( return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); } -const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { - return grammar->rules; +static bool is_digit_char(char c) { + return '0' <= c && c <= '9'; } -llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { - return grammar->stacks; +static bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); +} + +static std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); +} + +static const char * parse_space(const char * src, bool newline_ok) { + const char * pos = src; + while (*pos == ' ' || *pos == '\t' || *pos == '#' || + (newline_ok && (*pos == '\r' || *pos == '\n'))) { + if (*pos == '#') { + while (*pos && *pos != '\r' && *pos != '\n') { + pos++; + } + } else { + pos++; + } + } + return pos; +} + +static const char * parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting name at ") + src); + } + return pos; +} + +static const char * parse_int(const char * src) { + const char * pos = src; + while (is_digit_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting integer at ") + src); + } + return pos; +} + +static std::pair parse_char(const char * src) { + if (*src == '\\') { + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); + } + } else if (*src) { + return decode_utf8(src); + } + throw std::runtime_error("unexpected end of input"); +} + +static void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + fprintf(file, "%c", static_cast(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "", c); + } +} + +static bool is_char_element(llama_grammar_element elem) { + switch (elem.type) { + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; + default: return false; + } +} + +static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) { + for (auto elem : rule) { + switch (elem.type) { + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; + } + switch (elem.type) { + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; + } + } + fprintf(file, "\n"); +} + +static void print_rule( + FILE * file, + uint32_t rule_id, + const llama_grammar_rule & rule, + const std::map & symbol_id_names) { + if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + llama_grammar_element elem = rule[i]; + switch (elem.type) { + case LLAMA_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: + break; + default: + fprintf(file, "] "); + } + } + } + fprintf(file, "\n"); +} + +// +// implementation +// + +uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) { + uint32_t next_id = static_cast(symbol_ids.size()); + auto result = symbol_ids.emplace(std::string(src, len), next_id); + return result.first->second; +} + +uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) { + uint32_t next_id = static_cast(symbol_ids.size()); + symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; +} + +void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) { + if (rules.size() <= rule_id) { + rules.resize(rule_id + 1); + } + rules[rule_id] = rule; +} + +const char * llama_grammar_parser::parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested) { + llama_grammar_rule rule; + const char * pos = parse_sequence(src, rule_name, rule, is_nested); + while (*pos == '|') { + rule.push_back({LLAMA_GRETYPE_ALT, 0}); + pos = parse_space(pos + 1, true); + pos = parse_sequence(pos, rule_name, rule, is_nested); + } + rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(rule_id, rule); + return pos; +} + +const char * llama_grammar_parser::parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested) { + size_t last_sym_start = rule.size(); + const char * pos = src; + + auto handle_repetitions = [&](int min_times, int max_times) { + + if (last_sym_start == rule.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | + + llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + if (min_times == 0) { + rule.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); + } + } + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + llama_grammar_rule rec_rule(prev_rule); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(prev_rule.size()); + uint32_t rec_rule_id = generate_symbol_id( rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule( rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = rule.size(); + while (*pos != '"') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '[') { // char range(s) + pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } + last_sym_start = rule.size(); + while (*pos != ']') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum llama_gretype type = last_sym_start < rule.size() + ? LLAMA_GRETYPE_CHAR_ALT + : start_type; + + rule.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + if (!pos[1]) { + throw std::runtime_error("unexpected end of input"); + } + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(rule_name); + pos = parse_alternates(pos, rule_name, sub_rule_id, true); + last_sym_start = rule.size(); + // output reference to synthesized rule + rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '.') { // any char + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); + + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); + } + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + + int max_times = -1; + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else { + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); + } else { + break; + } + } + return pos; + } + +const char * llama_grammar_parser::parse_rule(const char * src) { + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); + } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); + } + +bool llama_grammar_parser::parse(const char * src) { + try { + const char * pos = parse_space(src, true); + while (*pos) { + pos = parse_rule(pos); + } + // Validate the state to ensure that all rules are defined + for (const auto & rule : rules) { + if (rule.empty()) { + throw std::runtime_error("Undefined rule"); + } + for (const auto & elem : rule) { + if (elem.type == LLAMA_GRETYPE_RULE_REF) { + // Ensure that the rule at that location exists + if (elem.value >= rules.size() || rules[elem.value].empty()) { + // Get the name of the rule that is missing + for (const auto & kv : symbol_ids) { + if (kv.second == elem.value) { + throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); + } + } + } + } + } + } + } catch (const std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + rules.clear(); + return false; + } + + return true; +} + +void llama_grammar_parser::print(FILE * file) { + try { + std::map symbol_id_names; + for (const auto & kv : symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, rules[i]); + print_rule(file, uint32_t(i), rules[i], symbol_id_names); + // fprintf(file, "\n"); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); + } +} + +llama_grammar_stack llama_grammar_parser::c_rules() const { + llama_grammar_stack ret; + ret.reserve(rules.size()); + for (const auto & rule : rules) { + ret.push_back(rule.data()); + } + return ret; } // returns true iff pos points to the end of one of the definitions of a rule @@ -89,7 +607,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) static std::pair llama_grammar_match_char( const llama_grammar_element * pos, const uint32_t chr) { - bool found = false; bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; @@ -225,16 +742,93 @@ static void llama_grammar_advance_stack( } } -// takes a set of possible pushdown stacks on a grammar, which are required to -// be positioned at a character range (see `llama_grammar_advance_stack`), and -// produces the N possible stacks if the given char is accepted at those -// positions +static llama_grammar_candidates llama_grammar_reject_candidates( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const llama_grammar_candidates & candidates) { + LM_GGML_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + return {}; + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + + return rejects; +} + +static bool llama_grammar_detect_left_recursion( + const llama_grammar_rules & rules, + size_t rule_index, + std::vector * rules_visited, + std::vector * rules_in_progress, + std::vector * rules_may_be_empty) { + if ((*rules_in_progress)[rule_index]) { + return true; + } + + (*rules_in_progress)[rule_index] = true; + + const llama_grammar_rule & rule = rules[rule_index]; + + // First check if the rule might produce the empty string. This could be done combined with the second + // step but it's more readable as two steps. + bool at_rule_start = true; + for (size_t i = 0; i < rule.size(); i++) { + if (llama_grammar_is_end_of_sequence(&rule[i])) { + if (at_rule_start) { + (*rules_may_be_empty)[rule_index] = true; + break; + } + at_rule_start = true; + } else { + at_rule_start = false; + } + } + + // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may + // be empty) + bool recurse_into_nonterminal = true; + for (size_t i = 0; i < rule.size(); i++) { + if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { + if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { + return true; + } + if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { + recurse_into_nonterminal = false; + } + } else if (llama_grammar_is_end_of_sequence(&rule[i])) { + recurse_into_nonterminal = true; + } else { + recurse_into_nonterminal = false; + } + } + + (*rules_in_progress)[rule_index] = false; + (*rules_visited)[rule_index] = true; + + return false; +} + +const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { + return grammar->rules; +} + +llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { + return grammar->stacks; +} + void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, const uint32_t chr, - llama_grammar_stacks & new_stacks) { - new_stacks.clear(); + llama_grammar_stacks & stacks_new) { + stacks_new.clear(); + stacks_new.reserve(stacks.size()); for (const auto & stack : stacks) { if (stack.empty()) { @@ -250,29 +844,11 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + llama_grammar_advance_stack(rules, new_stack, stacks_new); } } } -static llama_grammar_candidates llama_grammar_reject_candidates( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const llama_grammar_candidates & candidates) { - LM_GGML_ASSERT(!stacks.empty()); // REVIEW - - if (candidates.empty()) { - return {}; - } - - auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); - - for (size_t i = 1, size = stacks.size(); i < size; ++i) { - rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); - } - return rejects; -} - llama_grammar_candidates llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, @@ -328,72 +904,97 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( return rejects; } -static bool llama_grammar_detect_left_recursion( - const llama_grammar_rules & rules, - size_t rule_index, - std::vector * rules_visited, - std::vector * rules_in_progress, - std::vector * rules_may_be_empty) { - if ((*rules_in_progress)[rule_index]) { - return true; - } +//////////////////// - (*rules_in_progress)[rule_index] = true; +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; - const llama_grammar_rule & rule = rules[rule_index]; + // copy rule definitions into vectors + llama_grammar_rules vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } - // First check if the rule might produce the empty string. This could be done combined with the second - // step but it's more readable as two steps. - bool at_rule_start = true; - for (size_t i = 0; i < rule.size(); i++) { - if (llama_grammar_is_end_of_sequence(&rule[i])) { - if (at_rule_start) { - (*rules_may_be_empty)[rule_index] = true; - break; - } - at_rule_start = true; - } else { - at_rule_start = false; + // Check for left recursion + std::vector rules_visited(n_rules); + std::vector rules_in_progress(n_rules); + std::vector rules_may_be_empty(n_rules); + for (size_t i = 0; i < n_rules; i++) { + if (rules_visited[i]) { + continue; + } + if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + return nullptr; } } - // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may - // be empty) - bool recurse_into_nonterminal = true; - for (size_t i = 0; i < rule.size(); i++) { - if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { - if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { - return true; - } - if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { - recurse_into_nonterminal = false; - } - } else if (llama_grammar_is_end_of_sequence(&rule[i])) { - recurse_into_nonterminal = true; + // loop over alternates of start rule to build initial stacks + llama_grammar_stacks stacks; + pos = vec_rules[start_rule_index].data(); + do { + llama_grammar_stack stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; } else { - recurse_into_nonterminal = false; + break; } - } + } while (true); - (*rules_in_progress)[rule_index] = false; - (*rules_visited)[rule_index] = true; - return false; + // Important: vec_rules has to be moved here, not copied, because stacks contains + // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar + // then the pointers would be invalidated when the local vec_rules goes out of scope. + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; } -// -// grammar - external -// +struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { + llama_grammar_parser parser; + + // if there is a grammar, parse it + if (!parser.parse(grammar_str)) { + return nullptr; + } + + // will be empty (default) if there are parse errors + if (parser.rules.empty()) { + fprintf(stderr, "%s: failed to parse grammar\n", __func__); + return nullptr; + } + + // Ensure that there is a "root" node. + if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { + fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + return nullptr; + } + + std::vector grammar_rules(parser.c_rules()); + + const size_t n_rules = grammar_rules.size(); + const size_t start_rule_index = parser.symbol_ids.at(grammar_root); -struct llama_grammar * llama_grammar_init_impl( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { const llama_grammar_element * pos; // copy rule definitions into vectors llama_grammar_rules vec_rules(n_rules); for (size_t i = 0; i < n_rules; i++) { - for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { vec_rules[i].push_back(*pos); } vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); @@ -438,22 +1039,26 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { + if (grammar == nullptr) { + return; + } + delete grammar; } -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) { - llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 }; +struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { + llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { - for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { - for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { - if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { + for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { + if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { result->stacks[is][ie] = &result->rules[ir0][ir1]; } } @@ -464,14 +1069,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram return result; } -void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) { - LM_GGML_ASSERT(grammar); - LM_GGML_ASSERT(vocab); - - int64_t t_start_sample_us = lm_ggml_time_us(); +void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { + LM_GGML_ASSERT(grammar.vocab != nullptr); bool allow_eog = false; - for (const auto & stack : grammar->stacks) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { allow_eog = true; break; @@ -479,40 +1081,38 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc } std::vector, llama_partial_utf8>> candidates_decoded; - candidates_decoded.reserve(candidates->size); + candidates_decoded.reserve(cur_p->size); llama_grammar_candidates candidates_grammar; - candidates_grammar.reserve(candidates->size); + candidates_grammar.reserve(cur_p->size); - for (size_t i = 0; i < candidates->size; ++i) { - const llama_token id = candidates->data[i].id; - const std::string & piece = vocab->cache_token_to_piece.at(id); + for (size_t i = 0; i < cur_p->size; ++i) { + const llama_token id = cur_p->data[i].id; + const std::string & piece = grammar.vocab->cache_token_to_piece.at(id); - if (llama_token_is_eog_impl(*vocab, id)) { + if (llama_token_is_eog_impl(*grammar.vocab, id)) { if (!allow_eog) { - candidates->data[i].logit = -INFINITY; + cur_p->data[i].logit = -INFINITY; } } else if (piece.empty() || piece[0] == 0) { - candidates->data[i].logit = -INFINITY; + cur_p->data[i].logit = -INFINITY; } else { - candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); + candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } } - const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { - candidates->data[reject.index].logit = -INFINITY; + cur_p->data[reject.index].logit = -INFINITY; } - - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; } -void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { - const int64_t t_start_sample_us = lm_ggml_time_us(); +void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { + LM_GGML_ASSERT(grammar.vocab != nullptr); - if (llama_token_is_eog_impl(*vocab, token)) { - for (const auto & stack : grammar->stacks) { + if (llama_token_is_eog_impl(*grammar.vocab, token)) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; } @@ -520,20 +1120,19 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc LM_GGML_ABORT("fatal error"); } - const std::string & piece = vocab->cache_token_to_piece.at(token); + const std::string & piece = grammar.vocab->cache_token_to_piece.at(token); // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece, grammar->partial_utf8); + const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; - llama_grammar_stacks tmp_new_stacks; + llama_grammar_stacks stacks_new; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks); - grammar->stacks = tmp_new_stacks; + llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new); + grammar.stacks = std::move(stacks_new); } - grammar->partial_utf8 = decoded.second; - LM_GGML_ASSERT(!grammar->stacks.empty()); - - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; + grammar.partial_utf8 = decoded.second; + LM_GGML_ASSERT(!grammar.stacks.empty()); } diff --git a/cpp/llama-grammar.h b/cpp/llama-grammar.h index 695ea06..f529ce3 100644 --- a/cpp/llama-grammar.h +++ b/cpp/llama-grammar.h @@ -2,11 +2,115 @@ #include "llama-impl.h" +#include + struct llama_vocab; -struct llama_sampling; + +// grammar element type +enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) + LLAMA_GRETYPE_CHAR_ANY = 7, +}; + +typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID +} llama_grammar_element; + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; +}; + +using llama_grammar_rule = std::vector< llama_grammar_element>; +using llama_grammar_stack = std::vector; + +using llama_grammar_rules = std::vector; +using llama_grammar_stacks = std::vector; +using llama_grammar_candidates = std::vector; + +const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); + llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +void llama_grammar_accept( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + uint32_t chr, + llama_grammar_stacks & stacks_new); + +std::vector llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates); + +struct llama_grammar_parser { + std::map symbol_ids; + + llama_grammar_rules rules; + + llama_grammar_stack c_rules() const; + + uint32_t get_symbol_id(const char * src, size_t len); + uint32_t generate_symbol_id(const std::string & base_name); + + void add_rule(uint32_t rule_id, const llama_grammar_rule & rule); + + const char * parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested); + + const char * parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested); + + const char * parse_rule(const char * src); + + bool parse(const char * src); + void print(FILE * file); +}; struct llama_grammar { - const llama_grammar_rules rules; + // note: allow null vocab for testing (not great) + const llama_vocab * vocab; + + const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; // buffer for partially generated UTF-8 sequence from accepted tokens @@ -17,23 +121,24 @@ struct llama_grammar { // internal API // +// note: needed for tests (not great) struct llama_grammar * llama_grammar_init_impl( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); + const struct llama_vocab * vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + +struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); void llama_grammar_free_impl(struct llama_grammar * grammar); -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar); +struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar); -void llama_grammar_sample_impl( - const struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, - llama_token_data_array * candidates); +// TODO: move the API below as member functions of llama_grammar +void llama_grammar_apply_impl( + const struct llama_grammar & grammar, + llama_token_data_array * cur_p); -void llama_grammar_accept_token_impl( - struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, +void llama_grammar_accept_impl( + struct llama_grammar & grammar, llama_token token); diff --git a/cpp/llama-impl.h b/cpp/llama-impl.h index ac6ce52..3a06a2c 100644 --- a/cpp/llama-impl.h +++ b/cpp/llama-impl.h @@ -1,8 +1,11 @@ #pragma once -#define LLAMA_API_INTERNAL #include "llama.h" +#include +#include +#include + #ifdef __GNUC__ #ifdef __MINGW32__ #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) @@ -21,21 +24,158 @@ LLAMA_ATTRIBUTE_FORMAT(2, 3) void llama_log_internal (lm_ggml_log_level level, const char * format, ...); void llama_log_callback_default(lm_ggml_log_level level, const char * text, void * user_data); +#define LLAMA_LOG(...) llama_log_internal(LM_GGML_LOG_LEVEL_NONE , __VA_ARGS__) #define LLAMA_LOG_INFO(...) llama_log_internal(LM_GGML_LOG_LEVEL_INFO , __VA_ARGS__) #define LLAMA_LOG_WARN(...) llama_log_internal(LM_GGML_LOG_LEVEL_WARN , __VA_ARGS__) #define LLAMA_LOG_ERROR(...) llama_log_internal(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define LLAMA_LOG_DEBUG(...) llama_log_internal(LM_GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#define LLAMA_LOG_CONT(...) llama_log_internal(LM_GGML_LOG_LEVEL_CONT , __VA_ARGS__) // // helpers // +struct time_meas { + time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : lm_ggml_time_us()), t_acc(t_acc) {} + + ~time_meas() { + if (t_start_us >= 0) { + t_acc += lm_ggml_time_us() - t_start_us; + } + } + + const int64_t t_start_us; + + int64_t & t_acc; +}; + static void replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); } + +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +); + +// the ring buffer works similarly to std::deque, but with a fixed capacity +template +struct ring_buffer { + ring_buffer(size_t cap) : capacity(cap), data(cap) {} + + T & front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + const T & front() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + T & back() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + const T & back() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + void push_back(const T & value) { + if (capacity == 0) { + throw std::runtime_error("ring buffer: capacity is zero"); + } + + if (sz == capacity) { + // advance the start when buffer is full + first = (first + 1) % capacity; + } else { + sz++; + } + data[pos] = value; + pos = (pos + 1) % capacity; + } + + T pop_front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + T value = data[first]; + first = (first + 1) % capacity; + sz--; + return value; + } + + //T & operator[](size_t i) { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + //const T & at(size_t i) const { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + const T & rat(size_t i) const { + if (i >= sz) { + throw std::runtime_error("ring buffer: index out of bounds"); + } + return data[(first + sz - i - 1) % capacity]; + } + + std::vector to_vector() const { + std::vector result; + result.reserve(sz); + for (size_t i = 0; i < sz; i++) { + result.push_back(data[(first + i) % capacity]); + } + return result; + } + + void clear() { + // here only reset the status of the buffer + sz = 0; + first = 0; + pos = 0; + } + + bool empty() const { + return sz == 0; + } + + size_t size() const { + return sz; + } + + size_t capacity = 0; + size_t sz = 0; + size_t first = 0; + size_t pos = 0; + std::vector data; +}; diff --git a/cpp/llama-sampling.cpp b/cpp/llama-sampling.cpp index c76a75b..6d298a4 100644 --- a/cpp/llama-sampling.cpp +++ b/cpp/llama-sampling.cpp @@ -1,12 +1,53 @@ #include "llama-sampling.h" +#include "llama-vocab.h" +#include "llama-grammar.h" + #include +#include +#include +#include +#include +#include #include #include -#include #include +#include #include +static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) { + // iterator for the probabilities +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + + struct probs_iterator { + typedef std::input_iterator_tag iterator_category; + typedef float value_type; + typedef float * pointer; + typedef float & reference; + typedef ptrdiff_t difference_type; + + const llama_token_data * data; + + bool operator==(const probs_iterator & other) const { return data == other.data; } + bool operator!=(const probs_iterator & other) const { return data != other.data; } + const float & operator*() const { return data->p; } + probs_iterator & operator++() { ++data; return *this; } + probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; } + }; + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif + + std::discrete_distribution dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size}); + + return dist(rng); +} + +/* static void llama_log_softmax(float * array, size_t size) { float max_l = *std::max_element(array, array + size); float sum = 0.f; @@ -20,66 +61,76 @@ static void llama_log_softmax(float * array, size_t size) { array[i] = logf(array[i] / sum); } } +*/ + +static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { + if (temp <= 0.0f) { + // find the token with the highest logit and set the rest to -inf + size_t max_i = 0; + float max_l = cur_p->data[0].logit; + + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i ].logit > max_l) { + cur_p->data[max_i].logit = -INFINITY; + max_i = i; + max_l = cur_p->data[i].logit; + } else { + cur_p->data[i].logit = -INFINITY; + } + } -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); + return; } - smpl->rng.seed(seed); + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= temp; + } } -void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - LM_GGML_ASSERT(candidates->size > 0); - - const int64_t t_start_sample_us = lm_ggml_time_us(); +static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { + LM_GGML_ASSERT(cur_p->size > 0); // Sort the logits in descending order - if (!candidates->sorted) { - std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + if (!cur_p->sorted) { + std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }); - candidates->sorted = true; + cur_p->sorted = true; } - float max_l = candidates->data[0].logit; + float max_l = cur_p->data[0].logit; float cum_sum = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - float p = expf(candidates->data[i].logit - max_l); - candidates->data[i].p = p; + + for (size_t i = 0; i < cur_p->size; ++i) { + float p = expf(cur_p->data[i].logit - max_l); + cur_p->data[i].p = p; cum_sum += p; } - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].p /= cum_sum; - } - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= cum_sum; } } -void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { - // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast - // if (k >= (int32_t)candidates->size) { +static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) { + // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast + // if (k >= (int32_t)cur_p->size) { // return; // } - const int64_t t_start_sample_us = lm_ggml_time_us(); - if (k <= 0) { - k = candidates->size; + k = cur_p->size; } - k = std::max(k, (int) min_keep); - k = std::min(k, (int) candidates->size); + k = std::min(k, (int) cur_p->size); // Sort scores in descending order - if (!candidates->sorted) { + if (!cur_p->sorted) { auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; if (k <= 128) { - std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); + std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp); } else { constexpr int nbuckets = 128; constexpr float bucket_low = -10.0f; @@ -87,11 +138,11 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); constexpr float bucket_inter = -bucket_low * bucket_scale; - std::vector bucket_idx(candidates->size); + std::vector bucket_idx(cur_p->size); std::vector histo(nbuckets, 0); - for (int i = 0; i < (int)candidates->size; ++i) { - const float val = candidates->data[i].logit; + for (int i = 0; i < (int)cur_p->size; ++i) { + const float val = cur_p->data[i].logit; int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); ib = std::max(0, std::min(nbuckets-1, ib)); bucket_idx[i] = ib; @@ -101,20 +152,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra int ib = nbuckets - 1; for ( ; ib >= 0; --ib) { nhave += histo[ib]; - if (nhave >= k) break; + if (nhave >= k) { + break; + } } std::vector tmp_tokens(nhave); - auto ptr = tmp_tokens.data(); + auto * ptr = tmp_tokens.data(); std::vector bucket_ptrs; bucket_ptrs.reserve(nbuckets - ib); for (int j = nbuckets - 1; j >= ib; --j) { bucket_ptrs.push_back(ptr); ptr += histo[j]; } - for (int i = 0; i < (int)candidates->size; ++i) { + for (int i = 0; i < (int)cur_p->size; ++i) { int j = bucket_idx[i]; if (j >= ib) { - *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i]; + *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i]; } } @@ -127,196 +180,596 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra } std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); - std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data)); + std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data)); + + } + cur_p->sorted = true; + } + cur_p->size = k; +} +static uint32_t get_rng_seed(uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + // use system clock if std::random_device is not a true RNG + static bool is_rd_prng = std::random_device().entropy() == 0; + if (is_rd_prng) { + return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count(); } - candidates->sorted = true; + std::random_device rd; + return rd(); + } + return seed; +} + +// llama_sampler API + +const char * llama_sampler_name(const struct llama_sampler * smpl) { + if (!smpl->iface) { + return "(null)"; } - candidates->size = k; - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; + return smpl->iface->name(smpl); +} + +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + if (smpl->iface->accept) { + smpl->iface->accept(smpl, token); + } +} + +void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { + LM_GGML_ASSERT(smpl->iface->apply); + smpl->iface->apply(smpl, cur_p); +} + +void llama_sampler_reset(struct llama_sampler * smpl) { + if (smpl->iface->reset) { + smpl->iface->reset(smpl); + } +} + +struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + if (smpl->iface->clone) { + return smpl->iface->clone(smpl); + } + + if (smpl->ctx == nullptr) { + return new llama_sampler { + /* .iface = */ smpl->iface, + /* .ctx = */ nullptr, + }; } + + LM_GGML_ABORT("the sampler does not support cloning"); } -void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { - if (p >= 1.0f) { +void llama_sampler_free(struct llama_sampler * smpl) { + if (smpl == nullptr) { return; } - llama_sample_softmax_impl(smpl, candidates); + if (smpl->iface->free) { + smpl->iface->free(smpl); + } + + delete smpl; +} + +llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); + + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + // TODO: do not allocate each time + std::vector cur; + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { + /* .data = */ cur.data(), + /* .size = */ cur.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; + + llama_sampler_apply(smpl, &cur_p); + + LM_GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); + + auto token = cur_p.data[cur_p.selected].id; + + llama_sampler_accept(smpl, token); + + return token; +} + +// sampler chain + +static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { + return "chain"; +} + +static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto * smpl : chain->samplers) { + llama_sampler_accept(smpl, token); + } + + chain->n_sample++; +} + +static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto * smpl : chain->samplers) { + llama_sampler_apply(smpl, cur_p); + } +} + +static void llama_sampler_chain_reset(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_reset(smpl); + } + + chain->t_sample_us = 0; + chain->n_sample = 0; +} + +static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { + const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; + + auto * result = llama_sampler_chain_init(chain_src->params); + + for (auto * smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl)); + } + + return result; +} + +static void llama_sampler_chain_free(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_free(smpl); + } + + delete chain; +} + +static struct llama_sampler_i llama_sampler_chain_i = { + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, +}; + +struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { + return new llama_sampler { + /* .iface = */ &llama_sampler_chain_i, + /* .ctx = */ new llama_sampler_chain { + /* .params = */ params, + /* .samplers = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + }, + }; +} + +void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { + auto * p = (llama_sampler_chain *) chain->ctx; + p->samplers.push_back(smpl); +} + +struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + + if (i < 0 || (size_t) i >= p->samplers.size()) { + return nullptr; + } + + return p->samplers[i]; +} + +struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { + auto * p = (llama_sampler_chain *) chain->ctx; + + if (i < 0 || (size_t) i >= p->samplers.size()) { + return nullptr; + } + + auto * result = p->samplers[i]; + p->samplers.erase(p->samplers.begin() + i); + + return result; +} + +int llama_sampler_chain_n(const struct llama_sampler * chain) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + + return p->samplers.size(); +} + +// +// samplers +// + +// greedy + +static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) { + return "greedy"; +} + +static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + cur_p->selected = 0; + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) { + cur_p->selected = i; + } + } +} + +static struct llama_sampler_i llama_sampler_greedy_i = { + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_greedy() { + return new llama_sampler { + /* .iface = */ &llama_sampler_greedy_i, + /* .ctx = */ nullptr, + }; +} + +// dist + +struct llama_sampler_dist { + const uint32_t seed; + uint32_t seed_cur; + + std::mt19937 rng; +}; + +static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { + return "dist"; +} + +static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p); + + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); +} + +static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_dist *) smpl->ctx; + auto * result = llama_sampler_init_dist(ctx->seed); + + // copy the state + { + auto * result_ctx = (llama_sampler_dist *) result->ctx; + + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_dist_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static void llama_sampler_dist_free(struct llama_sampler * smpl) { + delete (llama_sampler_dist *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_dist_i = { + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, +}; + +struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { + auto seed_cur = get_rng_seed(seed); + return new llama_sampler { + /* .iface = */ &llama_sampler_dist_i, + /* .ctx = */ new llama_sampler_dist { + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + }, + }; +} + +// softmax + +static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) { + return "softmax"; +} + +static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + llama_sampler_softmax_impl(cur_p); +} + +static struct llama_sampler_i llama_sampler_softmax_i = { + /* .name = */ llama_sampler_softmax_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_softmax_apply, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_softmax() { + return new llama_sampler { + /* .iface = */ &llama_sampler_softmax_i, + /* .ctx = */ nullptr, + }; +} + +// top-k + +struct llama_sampler_top_k { + const int32_t k; +}; + +static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) { + return "top-k"; +} + +static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_k *) smpl->ctx; + llama_sampler_top_k_impl(cur_p, ctx->k); +} + +static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_k *) smpl->ctx; + return llama_sampler_init_top_k(ctx->k); +} + +static void llama_sampler_top_k_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_k *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_top_k_i = { + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, +}; + +struct llama_sampler * llama_sampler_init_top_k(int32_t k) { + return new llama_sampler { + /* .iface = */ &llama_sampler_top_k_i, + /* .ctx = */ new llama_sampler_top_k { + /* .k = */ k, + }, + }; +} + +// top-p + +struct llama_sampler_top_p { + const float p; + const size_t min_keep; +}; + +static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { + return "top-p"; +} + +static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_p *) smpl->ctx; + + if (ctx->p >= 1.0f) { + return; + } - const int64_t t_start_sample_us = lm_ggml_time_us(); + llama_sampler_softmax_impl(cur_p); // Compute the cumulative probabilities float cum_sum = 0.0f; - size_t last_idx = candidates->size; + size_t last_idx = cur_p->size; - for (size_t i = 0; i < candidates->size; ++i) { - cum_sum += candidates->data[i].p; + for (size_t i = 0; i < cur_p->size; ++i) { + cum_sum += cur_p->data[i].p; // Check if the running sum is at least p or if we have kept at least min_keep tokens // we set the last index to i+1 to indicate that the current iterate should be included in the set - if (cum_sum >= p && i + 1 >= min_keep) { + if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) { last_idx = i + 1; break; } } // Resize the output vector to keep only the top-p tokens - candidates->size = last_idx; + cur_p->size = last_idx; +} - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - } +static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_p *) smpl->ctx; + return llama_sampler_init_top_p(ctx->p, ctx->min_keep); +} + +static void llama_sampler_top_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_top_p_i = { + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, +}; + +struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_top_p_i, + /* .ctx = */ new llama_sampler_top_p { + /* .p = */ p, + /* .min_keep = */ min_keep, + }, + }; } -void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { - if (p <= 0.0f || !candidates->size) { +// min-p + +struct llama_sampler_min_p { + const float p; + const size_t min_keep; +}; + +static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) { + return "min-p"; +} + +static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_min_p *) smpl->ctx; + + if (ctx->p <= 0.0f || !cur_p->size) { return; } - const int64_t t_start_sample_us = lm_ggml_time_us(); - bool min_p_applied = false; - // if the candidates aren't sorted, try the unsorted implementation first - if (!candidates->sorted) { + // if the cur_p aren't sorted, try the unsorted implementation first + if (!cur_p->sorted) { std::vector filtered_tokens; float max_logit = -FLT_MAX; - for (size_t i = 0; i < candidates->size; ++i) { - max_logit = std::max(max_logit, candidates->data[i].logit); + for (size_t i = 0; i < cur_p->size; ++i) { + max_logit = std::max(max_logit, cur_p->data[i].logit); } - const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max + const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max - for (size_t i = 0; i < candidates->size; ++i) { - if (candidates->data[i].logit >= min_logit) { - filtered_tokens.push_back(candidates->data[i]); + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit >= min_logit) { + filtered_tokens.push_back(cur_p->data[i]); } } // if we have enough values the operation was a success - if (filtered_tokens.size() >= min_keep) { - memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); - candidates->size = filtered_tokens.size(); + if (filtered_tokens.size() >= ctx->min_keep) { + memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); + cur_p->size = filtered_tokens.size(); min_p_applied = true; } } - // if the candidates are sorted or the unsorted implementation failed, use this implementation + // if the cur_p are sorted or the unsorted implementation failed, use this implementation if (!min_p_applied) { // Sort the logits in descending order - if (!candidates->sorted) { - std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + if (!cur_p->sorted) { + std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }); - candidates->sorted = true; + cur_p->sorted = true; } - const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max + const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max size_t i = 1; // first token always matches - for (; i < candidates->size; ++i) { - if (candidates->data[i].logit < min_logit && i >= min_keep) { + for (; i < cur_p->size; ++i) { + if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) { break; // prob too small } } // Resize the output vector to keep only the matching tokens - candidates->size = i; - } - - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; + cur_p->size = i; } } -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { - if (z >= 1.0f || candidates->size <= 2) { - return; - } - - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - const int64_t t_start_sample_us = lm_ggml_time_us(); - - // Compute the first and second derivatives - std::vector first_derivatives(candidates->size - 1); - std::vector second_derivatives(candidates->size - 2); - - for (size_t i = 0; i < first_derivatives.size(); ++i) { - first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; - } - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; - } - - // Calculate absolute value of second derivatives - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = std::abs(second_derivatives[i]); - } - - // Normalize the second derivatives - { - const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); +static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_min_p *) smpl->ctx; + return llama_sampler_init_min_p(ctx->p, ctx->min_keep); +} - if (second_derivatives_sum > 1e-6f) { - for (float & value : second_derivatives) { - value /= second_derivatives_sum; - } - } else { - for (float & value : second_derivatives) { - value = 1.0f / second_derivatives.size(); - } - } - } +static void llama_sampler_min_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_min_p *) smpl->ctx; +} - float cum_sum = 0.0f; - size_t last_idx = candidates->size; - for (size_t i = 0; i < second_derivatives.size(); ++i) { - cum_sum += second_derivatives[i]; +static struct llama_sampler_i llama_sampler_min_p_i = { + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, +}; + +struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_min_p_i, + /* .ctx = */ new llama_sampler_min_p { + /* .p = */ p, + /* .min_keep = */ min_keep, + }, + }; +} - // Check if the running sum is greater than z or if we have kept at least min_keep tokens - if (cum_sum > z && i >= min_keep) { - last_idx = i; - break; - } - } +// typical - // Resize the output vector to keep only the tokens above the tail location - candidates->size = last_idx; +struct llama_sampler_typical { + const float p; + const size_t min_keep; +}; - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - } +static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) { + return "typical"; } -void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_typical *) smpl->ctx; + // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr - if (p >= 1.0f) { + if (ctx->p >= 1.0f) { return; } // Compute the softmax of logits and calculate entropy - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - - const int64_t t_start_sample_us = lm_ggml_time_us(); + llama_sampler_softmax_impl(cur_p); float entropy = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - entropy += -candidates->data[i].p * logf(candidates->data[i].p); + for (size_t i = 0; i < cur_p->size; ++i) { + entropy += -cur_p->data[i].p * logf(cur_p->data[i].p); } // Compute the absolute difference between negative log probability and entropy for each candidate std::vector shifted_scores; - for (size_t i = 0; i < candidates->size; ++i) { - float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); + for (size_t i = 0; i < cur_p->size; ++i) { + float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy); shifted_scores.push_back(shifted_score); } // Sort tokens based on the shifted_scores and their corresponding indices - std::vector indices(candidates->size); + std::vector indices(cur_p->size); std::iota(indices.begin(), indices.end(), 0); std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { @@ -329,197 +782,340 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar for (size_t i = 0; i < indices.size(); ++i) { size_t idx = indices[i]; - cum_sum += candidates->data[idx].p; + cum_sum += cur_p->data[idx].p; // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > p && i >= min_keep - 1) { + if (cum_sum > ctx->p && i >= ctx->min_keep - 1) { last_idx = i + 1; break; } } // Resize the output vector to keep only the locally typical tokens - std::vector new_candidates; + std::vector cur_p_new; for (size_t i = 0; i < last_idx; ++i) { size_t idx = indices[i]; - new_candidates.push_back(candidates->data[idx]); + cur_p_new.push_back(cur_p->data[idx]); } - // Replace the data in candidates with the new_candidates data - std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); - candidates->size = new_candidates.size(); - candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - } + // Replace the data in cur_p with the cur_p_new data + std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data); + cur_p->size = cur_p_new.size(); + cur_p->sorted = false; } -void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { - const int64_t t_start_sample_us = lm_ggml_time_us(); - - // no need to do anything if there is only one (or zero) candidates - if(candidates->size <= 1) { - return; - } +static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_typical *) smpl->ctx; + return llama_sampler_init_typical(ctx->p, ctx->min_keep); +} - // Calculate maximum possible entropy - float max_entropy = -logf(1.0f / candidates->size); +static void llama_sampler_typical_free(struct llama_sampler * smpl) { + delete (llama_sampler_typical *) smpl->ctx; +} - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); +static struct llama_sampler_i llama_sampler_typical_i = { + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, +}; + +struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_typical_i, + /* .ctx = */ new llama_sampler_typical { + /* .p = */ p, + /* .min_keep = */ min_keep, + }, + }; +} - // Calculate entropy of the softmax probabilities - float entropy = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - float prob = candidates->data[i].p; - if (prob > 0.0f) { // Ensure no log(0) - entropy -= prob * logf(prob); - } - } +// temp - // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above) - float normalized_entropy = entropy / max_entropy; +struct llama_sampler_temp { + const float temp; +}; - // Map the normalized entropy to the desired temperature range using the power function - float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); +static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) { + return "temp"; +} -#ifdef DEBUG - LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp); - LLAMA_LOG_INFO("Entropy: %f\n", entropy); - LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy); - LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy); - LLAMA_LOG_INFO("Exponent: %f\n", exponent_val); - LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); -#endif +static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_temp *) smpl->ctx; - // Apply the dynamically calculated temperature scaling - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].logit /= dyn_temp; - } + llama_sampler_temp_impl(cur_p, ctx->temp); +} - // Re-compute softmax probabilities after scaling logits with dynamic temperature - double max_l_double = candidates->data[0].logit; - double cum_sum_double = 0.0; - for (size_t i = 0; i < candidates->size; ++i) { - double p = exp(candidates->data[i].logit - max_l_double); - candidates->data[i].p = p; // Store the scaled probability - cum_sum_double += p; - } - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities - } +static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_temp *) smpl->ctx; + return llama_sampler_init_temp(ctx->temp); +} -#ifdef DEBUG - // Print the updated top 25 probabilities after temperature scaling - LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); - for (size_t i = 0; i < 25 && i < candidates->size; ++i) { - LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f); - } -#endif +static void llama_sampler_temp_free(struct llama_sampler * smpl) { + delete (llama_sampler_temp *) smpl->ctx; +} - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - } +static struct llama_sampler_i llama_sampler_temp_i = { + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, +}; + +struct llama_sampler * llama_sampler_init_temp(float temp) { + return new llama_sampler { + /* .iface = */ &llama_sampler_temp_i, + /* .ctx = */ new llama_sampler_temp { + /*.temp = */ temp, + }, + }; } -void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) { - const int64_t t_start_sample_us = lm_ggml_time_us(); +// temp-ext - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].logit /= temp; - } +struct llama_sampler_temp_ext { + const float temp; + const float delta; + const float exponent; +}; - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - } +static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) { + return "temp-ext"; } -void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { - if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { - return; - } - - const int64_t t_start_sample_us = lm_ggml_time_us(); +static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx; + if (ctx->delta > 0) { + const float min_temp = std::max(0.0f, ctx->temp - ctx->delta); + const float max_temp = ctx->temp + ctx->delta; - // Create a frequency map to count occurrences of each token in last_tokens - std::unordered_map token_count; - for (size_t i = 0; i < penalty_last_n; ++i) { - token_count[last_tokens[i]]++; - } + float exponent_val = ctx->exponent; - // Apply frequency and presence penalties to the candidates - for (size_t i = 0; i < candidates->size; ++i) { - const auto token_iter = token_count.find(candidates->data[i].id); - if (token_iter == token_count.end()) { - continue; + // no need to do anything if there is only one (or zero) candidates + if (cur_p->size <= 1) { + return; } - const int count = token_iter->second; + // Calculate maximum possible entropy + float max_entropy = -logf(1.0f / cur_p->size); - // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. - // This is common fix for this problem, which is to multiply by the penalty instead of dividing. - if (candidates->data[i].logit <= 0) { - candidates->data[i].logit *= penalty_repeat; - } else { - candidates->data[i].logit /= penalty_repeat; + llama_sampler_softmax_impl(cur_p); + + // Calculate entropy of the softmax probabilities + float entropy = 0.0f; + for (size_t i = 0; i < cur_p->size; ++i) { + float prob = cur_p->data[i].p; + if (prob > 0.0f) { // Ensure no log(0) + entropy -= prob * logf(prob); + } } - candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present; - } + // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above) + float normalized_entropy = entropy / max_entropy; - candidates->sorted = false; + // Map the normalized entropy to the desired temperature range using the power function + float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - } -} + #ifdef DEBUG + LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp); + LLAMA_LOG_INFO("Entropy: %f\n", entropy); + LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy); + LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy); + LLAMA_LOG_INFO("Exponent: %f\n", exponent_val); + LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); + #endif -void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, - float * logits, - float * logits_guidance, - float scale) { - LM_GGML_ASSERT(smpl); + // Apply the dynamically calculated temperature scaling + llama_sampler_temp_impl(cur_p, dyn_temp); - const auto t_start_sample_us = lm_ggml_time_us(); - const auto n_vocab = smpl->n_vocab; + // Re-compute softmax probabilities after scaling logits with dynamic temperature + const double max_l_double = cur_p->data[0].logit; - llama_log_softmax(logits, n_vocab); - llama_log_softmax(logits_guidance, n_vocab); + double cum_sum_double = 0.0; + for (size_t i = 0; i < cur_p->size; ++i) { + double p = exp(cur_p->data[i].logit - max_l_double); + cur_p->data[i].p = p; // Store the scaled probability + cum_sum_double += p; + } - for (int i = 0; i < n_vocab; ++i) { - auto & l = logits[i]; - const auto & g = logits_guidance[i]; + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities + } - l = scale * (l - g) + g; + #ifdef DEBUG + // Print the updated top 25 probabilities after temperature scaling + LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); + for (size_t i = 0; i < 25 && i < cur_p->size; ++i) { + LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f); + } + #endif + } else { + llama_sampler_temp_impl(cur_p, ctx->temp); } +} - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; +static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx; + return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent); } -llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - LM_GGML_ASSERT(smpl); +static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { + delete (llama_sampler_temp_ext *) smpl->ctx; +} - const int32_t n_vocab = float(smpl->n_vocab); +static struct llama_sampler_i llama_sampler_temp_ext_i = { + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, +}; + +struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { + return new llama_sampler { + /* .iface = */ &llama_sampler_temp_ext_i, + /* .ctx = */ new llama_sampler_temp_ext { + /* .temp = */ temp, + /* .delta = */ delta, + /* .exponent = */ exponent, + }, + }; +} - int64_t t_start_sample_us = lm_ggml_time_us(); +// xtc - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); +struct llama_sampler_xtc { + const float probability; + const float threshold; + const size_t min_keep; + + const uint32_t seed; + uint32_t seed_cur; + + std::mt19937 rng; +}; + +static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) { + return "xtc"; +} + +static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_xtc *) smpl->ctx; + + if (ctx->probability <= 0.0f + || ctx->threshold > 0.5f + || cur_p->size < 2) { + return; + } + + std::uniform_real_distribution distribution(0.0f, 1.0f); + float chance = distribution(ctx->rng); + if (chance > ctx->probability) return; + + // in case it's not sorted/recalculated yet + llama_sampler_softmax_impl(cur_p); + + int pos_last = 0; + + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].p >= ctx->threshold) { + pos_last = i; + } else break; + } + + if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) { + cur_p->data += pos_last; + cur_p->size -= pos_last; + } +} + +static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_xtc *) smpl->ctx; + auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed); + + // copy the state + { + auto * result_ctx = (llama_sampler_xtc *) result->ctx; + + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_xtc_free(struct llama_sampler * smpl) { + delete (llama_sampler_xtc *) smpl->ctx; +} + +static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_xtc *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler_i llama_sampler_xtc_i = { + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, +}; + +struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { + auto seed_cur = get_rng_seed(seed); + return new llama_sampler { + /* .iface = */ &llama_sampler_xtc_i, + /* .ctx = */ new llama_sampler_xtc { + /* .probability = */ p, + /* .threshold = */ t, + /* .min_keep = */ min_keep, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + }, + }; +} + +// mirostat + +struct llama_sampler_mirostat { + const int32_t n_vocab; + + const uint32_t seed; + uint32_t seed_cur; + + const float tau; + const float eta; + + const int32_t m; + + float mu; + + std::mt19937 rng; +}; + +static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { + return "mirostat"; +} + +static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_mirostat *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; float sum_ti_bi = 0.0; float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { + for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) { float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); + float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p); sum_ti_bi += t_i * b_i; sum_ti_sq += t_i * t_i; } @@ -527,109 +1123,1222 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama // Compute k from the estimated s_hat and target surprise value float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); + float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat); - // Sample the next word X using top-k sampling - llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1); - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = lm_ggml_time_us(); + llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); + llama_sampler_softmax_impl(cur_p); - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; + const int idx = llama_sample_dist(cur_p, ctx->rng); + + cur_p->selected = idx; + + float observed_surprise = -log2f(cur_p->data[idx].p); + float e = observed_surprise - ctx->tau; // Update mu using the learning rate and error - *mu = *mu - eta * e; + ctx->mu = ctx->mu - ctx->eta * e; +} + +static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; + auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - return X; + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; + } + + return result; } -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { - int64_t t_start_sample_us; - t_start_sample_us = lm_ggml_time_us(); +static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_mirostat *) smpl->ctx; + ctx->mu = 2.0f*ctx->tau; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} - llama_sample_softmax_impl(smpl, candidates); +static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { + delete (llama_sampler_mirostat *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_mirostat_i = { + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ llama_sampler_mirostat_free, +}; + +struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { + auto seed_cur = get_rng_seed(seed); + return new llama_sampler { + /* .iface = */ &llama_sampler_mirostat_i, + /* .ctx = */ new llama_sampler_mirostat { + /* .n_vocab = */ n_vocab, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .tau = */ tau, + /* .eta = */ eta, + /* .m = */ m, + /* .mu = */ 2.0f*tau, + /* .rng = */ std::mt19937(seed_cur), + }, + }; +} + +// mirostat v2 + +struct llama_sampler_mirostat_v2 { + const uint32_t seed; + uint32_t seed_cur; + + const float tau; + const float eta; + + float mu; + + std::mt19937 rng; +}; + +static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) { + return "mirostat-v2"; +} + +static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p); // Truncate the words with surprise values greater than mu - candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > *mu; + cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > ctx->mu; })); - if (candidates->size == 0) { - candidates->size = 1; - } - - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; + if (cur_p->size == 0) { + cur_p->size = 1; } // Normalize the probabilities of the remaining words - llama_sample_softmax_impl(smpl, candidates); + llama_sampler_softmax_impl(cur_p); - // Sample the next word X from the remaining words - llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = lm_ggml_time_us(); + const int idx = llama_sample_dist(cur_p, ctx->rng); - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; + cur_p->selected = idx; + + float observed_surprise = -log2f(cur_p->data[idx].p); + float e = observed_surprise - ctx->tau; // Update mu using the learning rate and error - *mu = *mu - eta * e; + ctx->mu = ctx->mu - ctx->eta * e; +} + +static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; + ctx->mu = 2.0f*ctx->tau; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx; + + auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; } - return X; + + return result; } -llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - const int64_t t_start_sample_us = lm_ggml_time_us(); +static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { + delete (llama_sampler_mirostat_v2 *) smpl->ctx; +} - // Find max element - auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit < b.logit; - }); +static struct llama_sampler_i llama_sampler_mirostat_v2_i = { + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, +}; + +struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { + auto seed_cur = get_rng_seed(seed); + return new llama_sampler { + /* .iface = */ &llama_sampler_mirostat_v2_i, + /* .ctx = */ new llama_sampler_mirostat_v2 { + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .tau = */ tau, + /* .eta = */ eta, + /* .mu = */ 2.0f*tau, + /* .rng = */ std::mt19937(seed_cur), + }, + }; +} + +// grammar + +struct llama_sampler_grammar { + const struct llama_vocab * vocab; + + std::string grammar_str; + std::string grammar_root; + + struct llama_grammar * grammar; +}; + +static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) { + return "grammar"; +} + +static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (ctx->grammar) { + llama_grammar_accept_impl(*ctx->grammar, token); + } +} + +static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (ctx->grammar) { + llama_grammar_apply_impl(*ctx->grammar, cur_p); + } +} - llama_token result = max_iter->id; - if (smpl) { - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - smpl->n_sample++; +static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (!ctx->grammar) { + return; } + + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); + + llama_grammar_free_impl(ctx->grammar); + ctx->grammar = grammar_new; +} + +static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; + + auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr); + + // copy the state + { + auto * result_ctx = (llama_sampler_grammar *) result->ctx; + + if (ctx->grammar) { + result_ctx->grammar_str = ctx->grammar_str; + result_ctx->grammar_root = ctx->grammar_root; + + result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar); + } + } + return result; } -llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) { - LM_GGML_ASSERT(smpl); +static void llama_sampler_grammar_free(struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_grammar *) smpl->ctx; - const int64_t t_start_sample_us = lm_ggml_time_us(); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + if (ctx->grammar) { + llama_grammar_free_impl(ctx->grammar); + } + + delete ctx; +} + +static struct llama_sampler_i llama_sampler_grammar_i = { + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, +}; + +struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { + auto * ctx = new llama_sampler_grammar; + + if (grammar_str != nullptr && grammar_str[0] != '\0') { + *ctx = { + /* .vocab = */ &vocab, + /* .grammar_str = */ grammar_str, + /* .grammar_root = */ grammar_root, + /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), + }; + } else { + *ctx = { + /* .vocab = */ &vocab, + /* .grammar_str = */ {}, + /* .grammar_root = */ {}, + /* .grammar = */ nullptr, + }; + } + + return new llama_sampler { + /* .iface = */ &llama_sampler_grammar_i, + /* .ctx = */ ctx, + }; +} + +// penalties + +struct llama_sampler_penalties { + const int32_t n_vocab; + const llama_token special_eos_id; + const llama_token linefeed_id; + + const int32_t penalty_last_n; + const float penalty_repeat; + const float penalty_freq; + const float penalty_present; + + const bool penalize_nl; + const bool ignore_eos; + + ring_buffer prev; +}; + +static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) { + return "penalties"; +} + +static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + if (ctx->penalty_last_n == 0) { + return; + } + + ctx->prev.push_back(token); +} + +static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + + if (ctx->ignore_eos) { + assert(ctx->special_eos_id >= 0); + + // optimistically check if the candidates are not yet sorted/shuffled/truncated + if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) { + cur_p->data[ctx->special_eos_id].logit = -INFINITY; + } else { + // else, search for the special EOS token + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].id == ctx->special_eos_id) { + cur_p->data[i].logit = -INFINITY; + break; + } + } + } + } + + if ((ctx->penalty_last_n == 0) || + (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { + return; + } + + bool nl_found = false; + size_t nl_idx = 0; + float nl_logit = -INFINITY; + if (!ctx->penalize_nl) { + assert(ctx->linefeed_id >= 0); + + // optimistically check if the candidates are not yet sorted/shuffled/truncated + if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) { + nl_found = true; + nl_idx = ctx->linefeed_id; + nl_logit = cur_p->data[ctx->linefeed_id].logit; + } else { + // else, search for the linefeed token + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].id == ctx->linefeed_id) { + nl_found = true; + nl_idx = i; + nl_logit = cur_p->data[i].logit; + break; + } + } + } + } + + // Create a frequency map to count occurrences of each token in last_tokens + // TODO: optimize this by maintaining the token count in the sampler context + using llama_token_cnt = std::unordered_map; + llama_token_cnt token_count; + + for (int i = 0; i < std::min(ctx->penalty_last_n, ctx->prev.size()); ++i) { + token_count[ctx->prev.rat(i)]++; + } + + // Apply frequency and presence penalties to the cur_p + for (size_t i = 0; i < cur_p->size; ++i) { + const auto token_iter = token_count.find(cur_p->data[i].id); + if (token_iter == token_count.end()) { + continue; + } + + const int count = token_iter->second; + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (cur_p->data[i].logit <= 0) { + cur_p->data[i].logit *= ctx->penalty_repeat; + } else { + cur_p->data[i].logit /= ctx->penalty_repeat; + } - std::vector probs; - probs.reserve(candidates->size); - for (size_t i = 0; i < candidates->size; ++i) { - probs.push_back(candidates->data[i].p); + cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present; } - std::discrete_distribution<> dist(probs.begin(), probs.end()); - int idx = dist(rng); + cur_p->sorted = false; - llama_token result = candidates->data[idx].id; + if (!ctx->penalize_nl && nl_found) { + // restore the logit of the newline token if it was penalized + cur_p->data[nl_idx].logit = nl_logit; + } +} - smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - smpl->n_sample++; +static void llama_sampler_penalties_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + ctx->prev.clear(); +} + +static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; + auto * result = llama_sampler_init_penalties( + ctx->n_vocab, + ctx->special_eos_id, + ctx->linefeed_id, + ctx->penalty_last_n, + ctx->penalty_repeat, + ctx->penalty_freq, + ctx->penalty_present, + ctx->penalize_nl, + ctx->ignore_eos); + + // copy the state + { + auto * result_ctx = (llama_sampler_penalties *) result->ctx; + + result_ctx->prev = ctx->prev; + } return result; } -llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng); +static void llama_sampler_penalties_free(struct llama_sampler * smpl) { + delete (llama_sampler_penalties *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_penalties_i = { + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, +}; + +struct llama_sampler * llama_sampler_init_penalties( + int32_t n_vocab, + llama_token special_eos_id, + llama_token linefeed_id, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos) { + if (linefeed_id == LLAMA_TOKEN_NULL) { + penalize_nl = true; + } + + if (special_eos_id == LLAMA_TOKEN_NULL) { + ignore_eos = false; + } + + penalty_last_n = std::max(penalty_last_n, 0); + + return new llama_sampler { + /* .iface = */ &llama_sampler_penalties_i, + /* .ctx = */ new llama_sampler_penalties { + /* .n_vocab = */ n_vocab, + /* .special_eos_id = */ special_eos_id, + /* .linefeed_id = */ linefeed_id, + /* .penalty_last_n = */ penalty_last_n, + /* .penalty_repeat = */ penalty_repeat, + /* .penalty_freq = */ penalty_freq, + /* .penalty_present = */ penalty_present, + /* .penalize_nl = */ penalize_nl, + /* .ignore_eos = */ ignore_eos, + /* .prev = */ ring_buffer(penalty_last_n), + }, + }; +} + +// DRY + +struct llama_sampler_dry { + int32_t total_context_size; + + const float dry_multiplier; + const float dry_base; + const int32_t dry_allowed_length; + const int32_t dry_penalty_last_n; + + std::unordered_multimap> dry_processed_breakers; + std::vector dry_repeat_count; + std::unordered_map dry_max_token_repeat; + ring_buffer last_tokens; +}; + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap>& token_sequences, int max_tail_len = -1) { + for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) { + std::string word = llama_detokenize(vocab, {token_id}, true); + if (word.find(str) != std::string::npos) { + token_sequences.emplace(token_id, std::vector()); + } else { + size_t word_len = word.size(), str_len = str.size(); + size_t pos = -1; + while ((pos = word.find(str[0], pos + 1)) != std::string::npos) { + bool match = true; + size_t i; + for (i = 1; i < str_len && i + pos < word_len; ++i) { + if (word[pos + i] != str[i]) { + match = false; + break; + } + } + if (match) { + std::vector tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false); + if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { + tokenization.resize(max_tail_len); + } + + // Ensure we don't already have a duplicate matching tokenization + auto its = token_sequences.equal_range(token_id); + bool found = false; + for (auto it = its.first; it != its.second; ++it) { + if (tokenization == it->second) { + found = true; + break; + } + } + if (!found) { + token_sequences.emplace(token_id, tokenization); + } + } + } + } + } +} + +static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) { + return "dry"; +} + +static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { + return; + } + + ctx->last_tokens.push_back(token); +} + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + + if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { + return; + } + + int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0); + int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size); + + if (last_n_repeat <= ctx->dry_allowed_length) { + return; + } + + ctx->dry_repeat_count.assign(last_n_repeat, 0); + ctx->dry_max_token_repeat.clear(); + + // Step 1: Look for restart sequences to limit the maximum repetition length. + // Work backwards through the context looking for any token that begins a restart sequence. + // + // The collection `restart_sequences` is a mapping from a "head" token to all "tail" + // sequences that together comprise a restart sequence. This allows us to quickly check + // whether each token is the head of a complete sequence. Most restart sequences are actually + // a single token, and for these the "tail" is an empty vector. + // + // If the token is a "head", test all restart sequences that begin with this token + // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and + // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The + // longest matching sequence (if any) is used to limit the maximum repetition length. + // + // Note that in the case case of a short sequence contained in a longer one, this might fail to + // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as + // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress + // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare. + // + // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we + // have already clamped the maximum tail sequence length when generating `restart_sequences`. + // With clamping, this scan is O(N) in the context length. + + int rep_limit = last_n_repeat; + for (int i = 0; i < last_n_repeat; ++i) { + llama_token token = ctx->last_tokens.rat(i); + auto its = ctx->dry_processed_breakers.equal_range(token); + if (its.first == ctx->dry_processed_breakers.end()) { + continue; + } + int longest_match = -1; + for (auto it = its.first; it != its.second; ++it) { + // Note that (*it) does not contain the head character, so seq_len will be + // the restart sequence length minus 1. + // In the common case of a single-token restart sequence, (*it) will be empty + // and we will trivially match. + int seq_len = (int)it->second.size(); + if (seq_len > longest_match && seq_len <= (int)i) { + bool match = true; + for (int offset = 0; offset < seq_len; ++offset) { + // The -1 when indexing `last_tokens` is because we already matched the head. + if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) { + match = false; + break; + } + } + if (match) { + longest_match = seq_len; + } + } + } + if (longest_match >= 0) { + // We found a restart sequence starting `i` tokens from the end and continuing for + // `longest_match` tokens. + rep_limit = i - longest_match; + break; + } + } + if (rep_limit < ctx->dry_allowed_length) { + return; + } + + // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in + // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing + // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences. + // + // This algorithm is not currently documented on Wikipedia, but there is a clear description here: + // https://ivanyu.me/blog/2014/10/15/z-algorithm/ + // + // The code below is adapted from the public domain implementation by the same author here: + // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py + // + // Example: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // ^ + // This `3` means that the last three tokens of the context (a b c) also appear here. + // + // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested + // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each + // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables + // ensure that the inner while loops only examine each token in the context once as the outer + // for loop iterates over the context. + + { + const int last = last_n_repeat - 1; + int rt = 0, lt = 0; + + for (int k = 1; k < last_n_repeat; ++k) { + if (k > rt) { + // If k is outside the current Z-box, do naive computation. + int n = 0; + while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) { + ++n; + } + ctx->dry_repeat_count[last - k] = std::min(n, rep_limit); + if (n > 0) { + lt = k; + rt = k+n-1; + } + } else { + // If k is inside the current Z-box, consider two cases. + + int p = k - lt; // Pair index. + int right_part_len = rt - k + 1; + + if (ctx->dry_repeat_count[last - p] < right_part_len) { + int n = std::min(ctx->dry_repeat_count[last - p], rep_limit); + ctx->dry_repeat_count[last - k] = n; + } else { + int i = rt + 1; + while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) { + i += 1; + } + + int n = std::min(i - k, rep_limit); + ctx->dry_repeat_count[last - k] = n; + lt = k; + rt = i - 1; + } + } + } + } + + // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length + // that would be generated by emitting each new token that would extend a sequence. + // + // Following the same example as above: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // + // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition. + // c: 3 -> 4 (from `a b c` to `a b c c`) + // b: 1 -> 2 (from `c` to `c b`) + // y: 2 -> 3 (from `b c` to `b c y`) + + for (int i = 0; i < last_n_repeat - 1; ++i) { + int repeat_len = ctx->dry_repeat_count[i]; + if (repeat_len >= ctx->dry_allowed_length) { + // This token ends a repeat, so the next token would continue one. + // By convention, the value of `repeat_len` only includes the tokens currently + // in the context, not the new token that would be added. + llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i); + // Track the maximum sequence ending in this token. + const auto& it = ctx->dry_max_token_repeat.find(token); + if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) { + ctx->dry_max_token_repeat[token] = repeat_len; + } + } + } + + // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens. + + // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`. + // Compute it from `penalty_base` and the approximate log of `std::numeric_limits::max()` + const float FLOAT_MAX_LOG = 88.7228391f; + int max_exponent = 0; + if (ctx->dry_base > 1.000001f) { + max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base); + } + + for (size_t i = 0; i < cur_p->size; ++i) { + const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id); + if (af_kvp != ctx->dry_max_token_repeat.end()) { + // Check all sequence breakers starting with this token + auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id); + bool is_single_token_breaker = false; + + for (auto it = range.first; it != range.second; ++it) { + if (it->second.empty()) { + is_single_token_breaker = true; + break; + } + } + + // Apply penalty only if it's not a single-token sequence breaker + if (!is_single_token_breaker) { + int repeat_exp = af_kvp->second - ctx->dry_allowed_length; + if (max_exponent > 0 && repeat_exp > max_exponent) { + repeat_exp = max_exponent; + } + float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp); + cur_p->data[i].logit -= penalty; + } + } + } + + cur_p->sorted = false; +} + +static void llama_sampler_dry_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + ctx->last_tokens.clear(); + ctx->dry_repeat_count.clear(); + ctx->dry_max_token_repeat.clear(); +} + +static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_dry *) smpl->ctx; + + // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying + auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); + // Copy the state, including the processed breakers + { + auto * result_ctx = (llama_sampler_dry *) result->ctx; + result_ctx->dry_processed_breakers = ctx->dry_processed_breakers; + result_ctx->dry_repeat_count = ctx->dry_repeat_count; + result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat; + result_ctx->last_tokens = ctx->last_tokens; + } + + return result; +} + +static void llama_sampler_dry_free(struct llama_sampler * smpl) { + delete (llama_sampler_dry *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_dry_i = { + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, +}; + +struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); + std::unordered_multimap> processed_breakers; + const int MAX_CHAR_LEN = 40; + const int MAX_SEQ_LEN = 20; + + const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { + // Process sequence breakers + for (size_t i = 0; i < num_breakers; ++i) { + if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) { + LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i); + continue; + } + + std::string sequence_break(seq_breakers[i]); + if (sequence_break.empty()) { + LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n"); + continue; + } + + if (sequence_break.size() > MAX_CHAR_LEN) { + LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN); + sequence_break.resize(MAX_CHAR_LEN); + } + + get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); + } + } + + return new llama_sampler { + /* .iface = */ &llama_sampler_dry_i, + /* .ctx = */ new llama_sampler_dry { + /* .total_context_size = */ context_size, + /* .dry_multiplier = */ dry_multiplier, + /* .dry_base = */ dry_base, + /* .dry_allowed_length = */ dry_allowed_length, + /* .dry_penalty_last_n = */ dry_penalty_last_n, + /* .dry_processed_breakers = */ std::move(processed_breakers), + /* .dry_repeat_count = */ dry_enabled ? std::vector(effective_dry_penalty_last_n, 0) : std::vector{}, + /* .dry_max_token_repeat = */ {}, + /* .last_tokens = */ dry_enabled ? ring_buffer(effective_dry_penalty_last_n) : ring_buffer(0), + }, + }; +} + +// wrapper for test-sampling.cpp +struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector>& seq_breakers) { + llama_vocab dummy_vocab; + auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0); + auto * ctx = (llama_sampler_dry *) result->ctx; + + // Process the token-based sequence breakers + ctx->dry_processed_breakers.clear(); + if (seq_breakers.empty()) { + LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n"); + } else { + for (const auto& breaker : seq_breakers) { + if (breaker.empty()) { + LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n"); + continue; + } + llama_token head_token = breaker[0]; + std::vector tail_tokens(breaker.begin() + 1, breaker.end()); + ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens)); + } + + if (ctx->dry_processed_breakers.empty()) { + LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n"); + } + } + + return result; +} + +// logit-bias + +struct llama_sampler_logit_bias { + const int32_t n_vocab; + + const std::vector logit_bias; + + std::vector to_search; +}; + +static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) { + return "logit-bias"; +} + +static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; + + if (ctx->logit_bias.empty()) { + return; + } + + ctx->to_search.clear(); + + // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) + for (const auto & lb : ctx->logit_bias) { + if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) { + cur_p->data[lb.token].logit += lb.bias; + } else { + ctx->to_search.push_back(lb); + } + } + + if (ctx->to_search.empty()) { + return; + } + + // search for the remaining candidates that were not found in the previous step + for (size_t i = 0; i < cur_p->size; ++i) { + for (const auto & lb : ctx->to_search) { + if (cur_p->data[i].id == lb.token) { + cur_p->data[i].logit += lb.bias; + break; + } + } + } +} + +static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; + return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); +} + +static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { + delete (llama_sampler_logit_bias *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_logit_bias_i = { + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_logit_bias_clone, + /* .free = */ llama_sampler_logit_bias_free, +}; + +struct llama_sampler * llama_sampler_init_logit_bias( + int32_t n_vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { + return new llama_sampler { + /* .iface = */ &llama_sampler_logit_bias_i, + /* .ctx = */ new llama_sampler_logit_bias { + /* .n_vocab = */ n_vocab, + /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), + /* .to_search = */ {}, + }, + }; +} + +// infill + +//#define LM_GGML_DEBUG_SAMPLER_INFILL + +struct llama_sampler_infill { + const struct llama_vocab * vocab; + + std::vector buf0; + std::vector buf1; +}; + +static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) { + return "infill"; +} + +static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_infill *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p); + +#if defined(LM_GGML_DEBUG_SAMPLER_INFILL) +#define LOG_DBG_CUR LLAMA_LOG_DEBUG +#else +#define LOG_DBG_CUR(...) +#endif + + for (size_t i = 0; i < cur_p->size; ++i) { + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + + float p_txt_sum = 0.0f; + float p_eog_sum = 0.0f; + + for (size_t i = 0; i < cur_p->size; ++i) { + if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) { + p_eog_sum += cur_p->data[i].p; + } else { + p_txt_sum += cur_p->data[i].p; + } + } + + const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; LM_GGML_UNUSED(rat); + + LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size); + + if (3*p_eog_sum*cur_p->size > p_txt_sum) { + LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum); + + // keep just the EOG tokens + const auto size_org = cur_p->size; + + cur_p->size = 0; + + float p_sum = 0.0f; + + for (size_t i = 0; i < size_org; ++i) { + if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) { + p_sum += cur_p->data[i].p; + + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + } + + return; + } + + size_t n_combined = 0; LM_GGML_UNUSED(n_combined); + + // combine tokens with common prefix + for (size_t i0 = 0; i0 < cur_p->size; ++i0) { + for (size_t i1 = 0; i1 < cur_p->size; ++i1) { + if (cur_p->data[i0].logit == -INFINITY) { + break; + } + + if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) { + continue; + } + + int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); + if (len0 < 0) { + ctx->buf0.resize(len0); + len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); + assert(len0 > 0); + } + + int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); + if (len1 < 0) { + ctx->buf1.resize(len1); + len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); + assert(len1 > 0); + } + + // token i0 is a prefix of token i1 + if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) { + int dst = i0; + int src = i1; + + // merge into the token with higher probability + if (cur_p->data[i1].p > cur_p->data[i0].p) { + std::swap(dst, src); + } + + cur_p->data[dst].p += cur_p->data[src].p; + cur_p->data[src].logit = -INFINITY; + cur_p->data[src].p = 0.0f; + + n_combined++; + } + } + } + + size_t n_non_eog = 0; + + size_t size_org = cur_p->size; + + float p_sum = 0.0f; + float thold = 0.2f; + + cur_p->size = 0; + + LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold); + + for (size_t i = 0; i < size_org; ++i) { + const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id); + + if (cur_p->data[i].p < thold && !is_eog) { + continue; + } + + if (!is_eog) { + ++n_non_eog; + } + + p_sum += cur_p->data[i].p; + + // keep this token + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + + LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog); + + // if no non-EOG tokens are left -> reduce cur_p to single EOT token + if (n_non_eog == 0) { + cur_p->size = 1; + cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab); + cur_p->data[0].logit = 1.0f; + + return; + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + + size_org = cur_p->size; + p_sum = 0.0f; + thold = 1.0/(n_non_eog + 1); + + cur_p->size = 0; + + LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold); + + for (size_t i = 0; i < size_org; ++i) { + const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id); + + if (cur_p->data[i].p < thold && !is_eog) { + continue; + } + + p_sum += cur_p->data[i].p; + + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + +#undef LOG_DBG_CUR +} + +static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_infill *) smpl->ctx; + return llama_sampler_init_infill_impl(*ctx->vocab); +} + +static void llama_sampler_infill_free(struct llama_sampler * smpl) { + delete (llama_sampler_infill *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_infill_i = { + /* .name = */ llama_sampler_infill_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_infill_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_infill_clone, + /* .free = */ llama_sampler_infill_free, +}; + +struct llama_sampler * llama_sampler_init_infill_impl( + const struct llama_vocab & vocab) { + return new llama_sampler { + /* .iface = */ &llama_sampler_infill_i, + /* .ctx = */ new llama_sampler_infill { + /* .vocab = */ &vocab, + /* .buf0 = */ std::vector(512), + /* .buf1 = */ std::vector(512), + }, + }; +} + +// utils + +uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { + if (smpl->iface == &llama_sampler_dist_i) { + return ((const llama_sampler_dist *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_mirostat_i) { + return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_mirostat_v2_i) { + return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_chain_i) { + const auto * ctx = (const llama_sampler_chain *) smpl->ctx; + for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) { + const uint32_t seed = llama_sampler_get_seed(*it); + if (seed != LLAMA_DEFAULT_SEED) { + return seed; + } + } + } + + return LLAMA_DEFAULT_SEED; +} + +// perf + +struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) { + struct llama_perf_sampler_data data = {}; + + if (chain == nullptr || chain->iface != &llama_sampler_chain_i) { + LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__); + } + + const auto * ctx = (const struct llama_sampler_chain *) chain->ctx; + + data.t_sample_ms = 1e-3 * ctx->t_sample_us; + data.n_sample = std::max(0, ctx->n_sample); + + return data; +} + +void llama_perf_sampler_print(const struct llama_sampler * chain) { + const auto data = llama_perf_sampler(chain); + + LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample); +} + +void llama_perf_sampler_reset(struct llama_sampler * chain) { + if (chain == nullptr || chain->iface != &llama_sampler_chain_i) { + LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__); + } + + auto * ctx = (struct llama_sampler_chain *) chain->ctx; + + ctx->t_sample_us = ctx->n_sample = 0; } diff --git a/cpp/llama-sampling.h b/cpp/llama-sampling.h index f7f8e3e..919f6fd 100644 --- a/cpp/llama-sampling.h +++ b/cpp/llama-sampling.h @@ -1,56 +1,48 @@ #pragma once -#include "llama-impl.h" +// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? -struct llama_sampling { - llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} +#include "llama-grammar.h" - std::mt19937 rng; +struct llama_vocab; +struct llama_grammar; - int32_t n_vocab = 0; +// sampler chain - mutable int64_t t_sample_us = 0; - mutable int32_t n_sample = 0; +struct llama_sampler_chain { + llama_sampler_chain_params params; - void reset_timings() const { - t_sample_us = 0; - n_sample = 0; - } -}; + std::vector samplers; + + // timing -// -// internal API -// - -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed); - -void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep); -void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); -void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp); - -void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present); - -void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, - float * logits, - float * logits_guidance, - float scale); - -llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); -llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); -llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); + mutable int64_t t_sample_us; + + mutable int32_t n_sample; +}; +struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab & vocab, + const char * grammar_str, + const char * grammar_root); + +struct llama_sampler * llama_sampler_init_infill_impl( + const struct llama_vocab & vocab); + +struct llama_sampler * llama_sampler_init_dry_impl( + const struct llama_vocab & vocab, + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const char ** seq_breakers, + size_t num_breakers); + +struct llama_sampler * llama_sampler_init_dry_testing( + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const std::vector>& seq_breakers); diff --git a/cpp/llama-vocab.cpp b/cpp/llama-vocab.cpp index 0a70ef3..b013507 100644 --- a/cpp/llama-vocab.cpp +++ b/cpp/llama-vocab.cpp @@ -50,7 +50,7 @@ struct naive_trie { res.first->second.insert(key + 1, len - 1, value); } } - std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) { + std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) const { if (len == 0 || offset == len) { return std::make_pair(key, offset); } @@ -58,17 +58,17 @@ struct naive_trie { auto res = children.find(c); if (res != children.end()) { return res->second.get_longest_prefix(key, len, offset + 1); - } else { - return std::make_pair(key, offset); } + + return std::make_pair(key, offset); } - struct naive_trie * traverse(const char c) { + const struct naive_trie * traverse(const char c) const { auto res = children.find(c); if (res != children.end()) { return &res->second; - } else { - return NULL; } + + return NULL; } std::map children; bool has_value; @@ -79,6 +79,15 @@ struct naive_trie { // impl // +struct llm_tokenizer { + llm_tokenizer() {} + virtual ~llm_tokenizer() = default; +}; + +llama_vocab::~llama_vocab() { + delete tokenizer; +} + int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const { LM_GGML_ASSERT(token_left.find(' ') == std::string::npos); LM_GGML_ASSERT(token_left.find('\n') == std::string::npos); @@ -187,10 +196,15 @@ struct llm_bigram_spm { size_t size; }; -struct llm_tokenizer_spm { - llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {} +struct llm_tokenizer_spm : llm_tokenizer { + llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {} +}; + +struct llm_tokenizer_spm_session { + llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {} void tokenize(const std::string & text, std::vector & output) { + // split string into utf8 chars int index = 0; size_t offs = 0; @@ -207,7 +221,7 @@ struct llm_tokenizer_spm { } // seed the work queue with all possible 2-character tokens. - for (size_t i = 1; i < symbols.size(); ++i) { + for (int i = 1; i < (int) symbols.size(); ++i) { try_add_bigram(i - 1, i); } @@ -271,7 +285,7 @@ struct llm_tokenizer_spm { return; } - resegment(symbols[p->second.first], output); + resegment(symbols[p->second.first], output); resegment(symbols[p->second.second], output); } @@ -279,7 +293,6 @@ struct llm_tokenizer_spm { if (left == -1 || right == -1) { return; } - const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); auto token = vocab.token_to_id.find(text); @@ -306,10 +319,11 @@ struct llm_tokenizer_spm { } const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_spm * spm_tokenizer; std::vector symbols; llm_bigram_spm::queue work_queue; - std::map> rev_merge; }; @@ -352,8 +366,8 @@ struct llm_bigram_bpe { size_t size; }; -struct llm_tokenizer_bpe { - llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) { +struct llm_tokenizer_bpe : llm_tokenizer { + llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() { LM_GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE); switch (vocab.type_pre) { case LLAMA_VOCAB_PRE_TYPE_LLAMA3: @@ -450,6 +464,20 @@ struct llm_tokenizer_bpe { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: + // Note: in theory, the special token (sentinel and image token) regex_exprs below + // are unnecessary, as they are split in `tokenizer_st_partition` anyway. + // However, since the upstream pre-tokenizer uses them, they are also + // included here (see https://huggingface.co/facebook/chameleon-7b). + regex_exprs = { + "", // Sentinel tokens + "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens + "([\\t\\n]| | )", // directly from tokenizer.json + "\\p{N}", // Individual digits + "[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -462,7 +490,14 @@ struct llm_tokenizer_bpe { } } - void append(const llama_vocab::id token_id, std::vector & output) const { + std::vector regex_exprs; +}; + +struct llm_tokenizer_bpe_session { + llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab), + bpe_tokenizer(static_cast(vocab.tokenizer)) {} + + static void append(const llama_vocab::id token_id, std::vector & output) { output.push_back(token_id); } @@ -501,12 +536,11 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - - const auto word_collection = unicode_regex_split(text, regex_exprs); + const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs); symbols_final.clear(); - for (auto & word : word_collection) { + for (const auto & word : word_collection) { work_queue = llm_bigram_bpe::queue(); symbols.clear(); @@ -529,7 +563,7 @@ struct llm_tokenizer_bpe { index++; symbols.emplace_back(sym); } - for (size_t i = 1; i < symbols.size(); ++i) { + for (int i = 1; i < (int) symbols.size(); ++i) { add_new_bigram(i - 1, i); } @@ -609,7 +643,6 @@ struct llm_tokenizer_bpe { if (left == -1 || right == -1) { return; } - std::string left_token = std::string(symbols[left].text, symbols[left].n); std::string right_token = std::string(symbols[right].text, symbols[right].n); @@ -633,12 +666,10 @@ struct llm_tokenizer_bpe { } const llama_vocab & vocab; - - std::vector regex_exprs; + const llm_tokenizer_bpe * bpe_tokenizer; std::vector symbols; std::vector symbols_final; - llm_bigram_bpe::queue work_queue; }; @@ -646,15 +677,17 @@ struct llm_tokenizer_bpe { // WPM tokenizer // -struct llm_tokenizer_wpm { - llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} +struct llm_tokenizer_wpm : llm_tokenizer { + llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {} +}; - void tokenize(const std::string & text, std::vector & output) const { - const auto & token_map = vocab.token_to_id; +struct llm_tokenizer_wpm_session { + llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {} + void tokenize(const std::string & text, std::vector & output) { + const auto & token_map = vocab.token_to_id; // normalize and split by whitespace std::vector words = preprocess(text); - // bos token prepended already // find the longest tokens that form the words @@ -699,7 +732,7 @@ struct llm_tokenizer_wpm { } // TODO: reduce string copies by using cpts_offs array - std::vector preprocess(const std::string & text) const { + static std::vector preprocess(const std::string & text) { const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); std::vector words(1, ""); @@ -751,15 +784,18 @@ struct llm_tokenizer_wpm { //(cpt >= 0xFF00 && cpt <= 0xFFEF); } +private: const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_wpm * wpm_tokenizer; }; // // UGM tokenizer // -struct llm_tokenizer_ugm { - llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) { +struct llm_tokenizer_ugm : llm_tokenizer { + llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() { if (vocab.precompiled_charsmap.size() > 0) { size_t charsmap_offset = 0; @@ -805,6 +841,30 @@ struct llm_tokenizer_ugm { unknown_token_score = min_score - unknown_token_score_penalty; } + // escaped space symbol - U+2581 (Lower One Eighth Block) + const std::string escaped_space = "\xE2\x96\x81"; + + const char * prefix_replacements = NULL; + size_t prefix_replacements_size = 0; + + const uint32_t * xcda_array = NULL; + size_t xcda_array_size = 0; + + struct naive_trie user_defined_token_matcher; + + float min_score = FLT_MAX; + float max_score = -FLT_MAX; + + float unknown_token_score_penalty = 10.0; + float unknown_token_score; + + struct naive_trie token_matcher; +}; + +struct llm_tokenizer_ugm_session { + llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab), + ugm_tokenizer(static_cast(vocab.tokenizer)) {} + /* This implementation is based on SentencePiece optimized Viterbi algorithm for * unigram language models. The general idea is to: * - move along the input sequence in steps of one UTF code point, @@ -843,7 +903,7 @@ struct llm_tokenizer_ugm { // traverse the token matcher trie to find a matching token bool single_codepoint_token_found = false; const struct best_tokenization & current_best = tokenization_results[input_offset]; - struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); + const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]); while (prefix_offset <= input_len && node != NULL) { // check if we found valid token in prefix @@ -873,7 +933,7 @@ struct llm_tokenizer_ugm { // if we didn't find a valid token corresponding to the whole UTF code point // then use unknown token as the tokenization of this UTF code point if (!single_codepoint_token_found) { - const double challenger_score = current_best.score_sum + unknown_token_score; + const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score; prefix_offset = input_offset + n_utf8_code_units; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { @@ -905,7 +965,6 @@ struct llm_tokenizer_ugm { } private: - const llama_vocab & vocab; // helper structure for returning normalization results struct normalization_result { @@ -918,7 +977,7 @@ struct llm_tokenizer_ugm { normalized->clear(); normalized->reserve(input.size() * 3); - const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " "; + const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " "; bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; @@ -963,7 +1022,7 @@ struct llm_tokenizer_ugm { /* * This structure is a view wrapper for XOR-compressed double array (XCDA) * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries. - * Eeach bit-packed entry contains: + * Each bit-packed entry contains: * - BASE array value in bits 10-30 * - LCHECK array value in bits 0-7 * - LEAF array value in bit 9 @@ -1000,13 +1059,21 @@ struct llm_tokenizer_ugm { size_t xcda_array_size; }; + // this structure stores the best tokenization so far at input_offset + struct best_tokenization { + llama_token token_id; + size_t input_offset; + float score_sum; + }; + struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { if (input_offset == input.size()) { return { &input[input_offset], 0, 0 }; } // if input prefix matches some user-defined token return this token as normalization result - auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); + auto user_defined_token_match = + ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); if (user_defined_token_match.second > 0) { return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; } @@ -1014,8 +1081,8 @@ struct llm_tokenizer_ugm { size_t longest_prefix_length = 0; size_t longest_prefix_offset = 0; - if (xcda_array_size > 0) { - struct xcda_array_view xcda_view(xcda_array, xcda_array_size); + if (ugm_tokenizer->xcda_array_size > 0) { + struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size); // Find the longest normalized sequence matching the input prefix by walking // the XOR-compressed compact double array (XCDA) starting from the root node @@ -1051,52 +1118,162 @@ struct llm_tokenizer_ugm { if (longest_prefix_length > 0) { // we have a match, so return the replacement sequence - if (longest_prefix_offset >= prefix_replacements_size) { + if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) { throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); } - const char * prefix_replacement = &prefix_replacements[longest_prefix_offset]; + const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset]; return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length }; - } else { - // check if the input prefix contains a valid sequence of UTF-8 code units - try { - // if yes, return this sequence unmodified - size_t prefix_offset = input_offset; - unicode_cpt_from_utf8(input, prefix_offset); - return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; - } catch (std::invalid_argument & /*ex*/) { - // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER - return { "\xEF\xBF\xBD", 3, 1 }; - } + } + + // check if the input prefix contains a valid sequence of UTF-8 code units + try { + // if yes, return this sequence unmodified + size_t prefix_offset = input_offset; + unicode_cpt_from_utf8(input, prefix_offset); + return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; + } catch (std::invalid_argument & /*ex*/) { + // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER + return { "\xEF\xBF\xBD", 3, 1 }; } } - // escaped space symbol - U+2581 (Lower One Eighth Block) - const std::string escaped_space = "\xE2\x96\x81"; + const llama_vocab & vocab; + const llm_tokenizer_ugm * ugm_tokenizer; +}; - const char * prefix_replacements = NULL; - size_t prefix_replacements_size = 0; +// +// RWKV tokenizer +// - const uint32_t * xcda_array = NULL; - size_t xcda_array_size = 0; +static std::vector llama_unescape_rwkv_token(const std::string & escaped) { + std::vector output; + output.reserve(escaped.size()); + + // Parser state + bool escaping = false; + uint8_t hex_remaining = 0; + uint8_t hex_acc = 0; + + // Step through characters, performing parsing + for (const char & c : escaped) { + // If we're parsing a hex code, interpret the next character + if (hex_remaining != 0) { + uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0'); + hex_acc = (hex_acc << 4) + value; + + hex_remaining -= 1; + if (hex_remaining == 0) { + output.push_back(hex_acc); + hex_acc = 0; + } - struct naive_trie user_defined_token_matcher; + continue; + } - // this structure stores the best tokenization so far at input_offset - struct best_tokenization { - llama_token token_id; - size_t input_offset; - float score_sum; - }; + // If we got an escape character, interpret it + if (escaping) { + if (c == 't') { + output.push_back('\t'); + } else if (c == 'n') { + output.push_back('\n'); + } else if (c == 'r') { + output.push_back('\r'); + } else if (c == 'x') { + hex_remaining = 2; + } else { + output.push_back(c); + } - float min_score = FLT_MAX; - float max_score = -FLT_MAX; + escaping = false; + continue; + } - float unknown_token_score_penalty = 10.0; - float unknown_token_score; + if (c == '\\') { + escaping = true; + continue; + } + + output.push_back(c); + } + + return output; +} + +struct llm_tokenizer_rwkv : llm_tokenizer { + llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() { + // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. + // For now, we decode the vocab here into the lookup we'll use for tokenization. + + // build trie + for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) { + const auto & token = vocab.id_to_token[id]; + const auto data = llama_unescape_rwkv_token(token.text); + token_matcher.insert((const char *) data.data(), data.size(), id); + } + } struct naive_trie token_matcher; }; +struct llm_tokenizer_rwkv_session { + llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab), + rwkv_tokenizer(static_cast(*vocab.tokenizer)) {} + + void tokenize(const std::string & text, std::vector & output) { + uint32_t position = 0; + while (position < text.size()) { + const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]); + if (node == NULL) { + // no matching token found, add unknown token + output.push_back(vocab.special_unk_id); + position += 1; + continue; + } + + // traverse the trie to find the longest matching token + uint32_t token_id = 0; + uint32_t token_length = 0; + while (node != NULL) { + if (node->has_value) { + token_id = node->value; + token_length = position + 1; + } + node = node->traverse(text[++position]); + } + + // add the longest matching token + output.push_back(token_id); + position = token_length; + } + } + +private: + const llama_vocab & vocab; + const llm_tokenizer_rwkv & rwkv_tokenizer; +}; + +void llama_vocab::init_tokenizer() { + switch (type) { + case LLAMA_VOCAB_TYPE_SPM: + tokenizer = new llm_tokenizer_spm(*this); + break; + case LLAMA_VOCAB_TYPE_BPE: + tokenizer = new llm_tokenizer_bpe(*this); + break; + case LLAMA_VOCAB_TYPE_WPM: + tokenizer = new llm_tokenizer_wpm(*this); + break; + case LLAMA_VOCAB_TYPE_UGM: + tokenizer = new llm_tokenizer_ugm(*this); + break; + case LLAMA_VOCAB_TYPE_RWKV: + tokenizer = new llm_tokenizer_rwkv(*this); + break; + default: + LM_GGML_ABORT("unsupported vocab type"); + } +} + // // (de-) tokenize // @@ -1158,7 +1335,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // if a fragment is text ( not yet processed ) if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto & raw_text = fragment.raw_text; + const auto & raw_text = fragment.raw_text; auto raw_text_base_offset = fragment.offset; auto raw_text_base_length = fragment.length; @@ -1257,7 +1434,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< } } -std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) { +std::vector llama_tokenize_internal( + const llama_vocab & vocab, + std::string raw_text, + bool add_special, + bool parse_special) { + LM_GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first."); + std::vector output; std::forward_list fragment_buffer; @@ -1294,9 +1477,9 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - llm_tokenizer_spm tokenizer(vocab); llama_escape_whitespace(raw_text); - tokenizer.tokenize(raw_text, output); + llm_tokenizer_spm_session session(vocab); + session.tokenize(raw_text, output); is_prev_special = false; } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); @@ -1318,10 +1501,11 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe tokenizer(vocab); - + llm_tokenizer_bpe_session session(vocab); + // it calls some other methods that are not exist in llm_tokenizer, + // here just cast it to bpe tokenizer object if (add_special) { - tokenizer.append_bos(output); + session.append_bos(output); } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -1330,15 +1514,15 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - tokenizer.append(fragment.token, output); + session.append(fragment.token, output); } } if (add_special) { - tokenizer.append_eos(output); - tokenizer.check_double_bos_eos(output); + session.append_eos(output); + session.check_double_bos_eos(output); } } break; case LLAMA_VOCAB_TYPE_WPM: @@ -1348,7 +1532,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, output.push_back(vocab.special_cls_id); } - llm_tokenizer_wpm tokenizer(vocab); + llm_tokenizer_wpm_session session(vocab); for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -1357,7 +1541,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } @@ -1370,12 +1554,11 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, } break; case LLAMA_VOCAB_TYPE_UGM: { - llm_tokenizer_ugm tokenizer(vocab); - - if (add_special && vocab.tokenizer_add_bos != 0) { + if (add_special && vocab.tokenizer_add_bos) { LM_GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); } + llm_tokenizer_ugm_session session(vocab); for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -1383,24 +1566,41 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } } - if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) { + if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { LLAMA_LOG_WARN( "%s: Added a BOS token to the prompt as specified by the model but the prompt " "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (add_special && vocab.tokenizer_add_eos == 1) { + if (add_special && vocab.tokenizer_add_eos) { LM_GGML_ASSERT(vocab.special_eos_id != -1); output.push_back(vocab.special_eos_id); } } break; + case LLAMA_VOCAB_TYPE_RWKV: + { + llm_tokenizer_rwkv_session session(vocab); + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + + session.tokenize(raw_text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + } break; case LLAMA_VOCAB_TYPE_NONE: LM_GGML_ABORT("fatal error"); } @@ -1448,11 +1648,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla } bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) { - return token != -1 && ( - token == llama_token_eos_impl(vocab) || - token == llama_token_eot_impl(vocab) || - token == llama_token_eom_impl(vocab) - ); + return token != -1 && vocab.special_eog_ids.count(token) > 0; } bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) { @@ -1467,6 +1663,14 @@ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) { return vocab.special_eos_id; } +llama_token llama_token_eot_impl(const struct llama_vocab & vocab) { + return vocab.special_eot_id; +} + +llama_token llama_token_eom_impl(const struct llama_vocab & vocab) { + return vocab.special_eom_id; +} + llama_token llama_token_cls_impl(const struct llama_vocab & vocab) { return vocab.special_cls_id; } @@ -1492,33 +1696,49 @@ bool llama_add_eos_token_impl(const struct llama_vocab & vocab) { } llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) { - return vocab.special_prefix_id; + return vocab.special_fim_pre_id; } llama_token llama_token_middle_impl(const struct llama_vocab & vocab) { - return vocab.special_middle_id; + return vocab.special_fim_mid_id; } llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) { - return vocab.special_suffix_id; + return vocab.special_fim_suf_id; } -llama_token llama_token_eot_impl(const struct llama_vocab & vocab) { - return vocab.special_eot_id; +llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_pre_id; } -llama_token llama_token_eom_impl(const struct llama_vocab & vocab) { - return vocab.special_eom_id; +llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_suf_id; +} + +llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_mid_id; +} + +llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_pad_id; +} + +llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_rep_id; +} + +llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_sep_id; } int32_t llama_tokenize_impl( - const struct llama_vocab & vocab, - const char * text, - int32_t text_len, - llama_token * tokens, - int32_t n_tokens_max, - bool add_special, - bool parse_special) { + const struct llama_vocab & vocab, + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special) { auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); @@ -1595,11 +1815,13 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { return _try_copy(token_text.data(), token_text.size()); - } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + } + if (attr & LLAMA_TOKEN_ATTR_NORMAL) { std::string result = token_text; llama_unescape_whitespace(result); return _try_copy(result.data(), result.size()); - } else if (attr & LLAMA_TOKEN_ATTR_BYTE) { + } + if (attr & LLAMA_TOKEN_ATTR_BYTE) { char byte = (char) llama_token_to_byte(vocab, token); return _try_copy((char*) &byte, 1); } @@ -1610,12 +1832,24 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { return _try_copy(token_text.data(), token_text.size()); - } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + } + if (attr & LLAMA_TOKEN_ATTR_NORMAL) { std::string result = llama_decode_text(token_text); return _try_copy(result.data(), result.size()); } break; } + case LLAMA_VOCAB_TYPE_RWKV: { + std::vector result = llama_unescape_rwkv_token(token_text); + + // If we don't have enough space, return an error + if (result.size() > (size_t)length) { + return -(int)result.size(); + } + + memcpy(buf, result.data(), result.size()); + return (int)result.size(); + } default: LM_GGML_ABORT("fatal error"); } @@ -1632,6 +1866,8 @@ int32_t llama_detokenize_impl( int32_t text_len_max, bool remove_special, bool unparse_special) { + LM_GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first."); + int32_t avail = text_len_max; int32_t total = 0; @@ -1730,3 +1966,19 @@ int32_t llama_detokenize_impl( return total <= text_len_max ? total : -total; } + +std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector & tokens, bool special) { + std::string text; + text.resize(std::max(text.capacity(), tokens.size())); + int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + if (n_chars < 0) { + text.resize(-n_chars); + n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + LM_GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization + } + + text.resize(n_chars); + + // NOTE: the original tokenizer decodes bytes after collecting the pieces. + return text; +} diff --git a/cpp/llama-vocab.h b/cpp/llama-vocab.h index 6e8f30b..4bb16d2 100644 --- a/cpp/llama-vocab.h +++ b/cpp/llama-vocab.h @@ -6,6 +6,9 @@ #include #include #include +#include + +struct llm_tokenizer; struct llama_vocab { using id = llama_token; @@ -18,6 +21,8 @@ struct llama_vocab { tattr attr; }; + uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -32,37 +37,51 @@ struct llama_vocab { std::map, int> bpe_ranks; // default LLaMA special tokens + // TODO: should we set all of these to LLAMA_TOKEN_NULL? id special_bos_id = 1; id special_eos_id = 2; + id special_eot_id = LLAMA_TOKEN_NULL; + id special_eom_id = LLAMA_TOKEN_NULL; id special_unk_id = 0; - id special_sep_id = -1; - id special_pad_id = -1; - id special_cls_id = -1; - id special_mask_id = -1; - - id linefeed_id = 13; - id special_prefix_id = -1; - id special_suffix_id = -1; - id special_middle_id = -1; - id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token - id special_eom_id = -1; + id special_sep_id = LLAMA_TOKEN_NULL; + id special_pad_id = LLAMA_TOKEN_NULL; + id special_cls_id = LLAMA_TOKEN_NULL; + id special_mask_id = LLAMA_TOKEN_NULL; + + id linefeed_id = 13; + + // fim tokens + id special_fim_pre_id = LLAMA_TOKEN_NULL; + id special_fim_suf_id = LLAMA_TOKEN_NULL; + id special_fim_mid_id = LLAMA_TOKEN_NULL; + id special_fim_pad_id = LLAMA_TOKEN_NULL; + id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo + id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator + + // set of all tokens that cause "end of generation" + std::set special_eog_ids; // tokenizer flags - bool tokenizer_add_space_prefix = false; - bool tokenizer_add_bos = false; - bool tokenizer_add_eos = false; - bool tokenizer_ignore_merges = false; - bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces + bool tokenizer_add_space_prefix = false; + bool tokenizer_add_bos = false; + bool tokenizer_add_eos = false; + bool tokenizer_ignore_merges = false; + bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces bool tokenizer_remove_extra_whitespaces = false; bool tokenizer_escape_whitespaces = true; bool tokenizer_treat_whitespace_as_suffix = false; std::vector precompiled_charsmap; + llm_tokenizer * tokenizer = nullptr; + + llama_vocab() = default; + ~llama_vocab(); + int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; -}; -const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx); + void init_tokenizer(); +}; // // internal API @@ -76,6 +95,7 @@ std::vector llama_tokenize_internal( bool add_special, bool parse_special = false); +// TODO: move the API below as member functions of llama_vocab llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch); const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token); @@ -90,19 +110,26 @@ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token t llama_token llama_token_bos_impl(const struct llama_vocab & vocab); llama_token llama_token_eos_impl(const struct llama_vocab & vocab); +llama_token llama_token_eot_impl(const struct llama_vocab & vocab); +llama_token llama_token_eom_impl(const struct llama_vocab & vocab); llama_token llama_token_cls_impl(const struct llama_vocab & vocab); llama_token llama_token_sep_impl(const struct llama_vocab & vocab); llama_token llama_token_nl_impl (const struct llama_vocab & vocab); llama_token llama_token_pad_impl(const struct llama_vocab & vocab); -bool llama_add_bos_token_impl(const struct llama_vocab & vocab); -bool llama_add_eos_token_impl(const struct llama_vocab & vocab); - llama_token llama_token_prefix_impl(const struct llama_vocab & vocab); llama_token llama_token_middle_impl(const struct llama_vocab & vocab); llama_token llama_token_suffix_impl(const struct llama_vocab & vocab); -llama_token llama_token_eot_impl (const struct llama_vocab & vocab); -llama_token llama_token_eom_impl (const struct llama_vocab & vocab); + +llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab); + +bool llama_add_bos_token_impl(const struct llama_vocab & vocab); +bool llama_add_eos_token_impl(const struct llama_vocab & vocab); int32_t llama_tokenize_impl( const struct llama_vocab & vocab, @@ -122,6 +149,12 @@ int32_t llama_token_to_piece_impl( int32_t lstrip, bool special); +// check if token0 is contained as a prefix in token1 +bool llama_token_is_prefix_impl( + const struct llama_vocab & vocab, + llama_token token0, + llama_token token1); + int32_t llama_detokenize_impl( const struct llama_vocab & vocab, const llama_token * tokens, @@ -130,3 +163,8 @@ int32_t llama_detokenize_impl( int32_t text_len_max, bool remove_special, bool unparse_special); + +std::string llama_detokenize( + const struct llama_vocab & vocab, + const std::vector & tokens, + bool special); diff --git a/cpp/llama.cpp b/cpp/llama.cpp index 02ad32f..d0c4f5d 100644 --- a/cpp/llama.cpp +++ b/cpp/llama.cpp @@ -1,6 +1,5 @@ #include "llama-impl.h" #include "llama-vocab.h" -#include "llama-grammar.h" #include "llama-sampling.h" #include "unicode.h" @@ -8,30 +7,7 @@ #include "ggml.h" #include "ggml-alloc.h" #include "ggml-backend.h" - -#ifdef LM_GGML_USE_RPC -# include "ggml-rpc.h" -#endif - -#ifdef LM_GGML_USE_CUDA -# include "ggml-cuda.h" -#elif defined(LM_GGML_USE_VULKAN) -# include "ggml-vulkan.h" -#elif defined(LM_GGML_USE_SYCL) -# include "ggml-sycl.h" -#elif defined(LM_GGML_USE_KOMPUTE) -# include "ggml-kompute.h" -#elif defined(LM_GGML_USE_CANN) -# include "ggml-cann.h" -#endif - -#ifdef LM_GGML_USE_BLAS -# include "ggml-blas.h" -#endif - -#ifdef LM_GGML_USE_METAL -# include "ggml-metal.h" -#endif +#include "ggml-cpp.h" // TODO: replace with ggml API call #define QK_K 256 @@ -205,6 +181,7 @@ enum llm_arch { LLM_ARCH_ORION, LLM_ARCH_INTERNLM2, LLM_ARCH_MINICPM, + LLM_ARCH_MINICPM3, LLM_ARCH_GEMMA, LLM_ARCH_GEMMA2, LLM_ARCH_STARCODER2, @@ -213,6 +190,7 @@ enum llm_arch { LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, LLM_ARCH_OLMO, + LLM_ARCH_OLMOE, LLM_ARCH_OPENELM, LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, @@ -223,6 +201,10 @@ enum llm_arch { LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, LLM_ARCH_EXAONE, + LLM_ARCH_RWKV6, + LLM_ARCH_GRANITE, + LLM_ARCH_GRANITE_MOE, + LLM_ARCH_CHAMELEON, LLM_ARCH_UNKNOWN, }; @@ -252,6 +234,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ORION, "orion" }, { LLM_ARCH_INTERNLM2, "internlm2" }, { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_MINICPM3, "minicpm3" }, { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_STARCODER2, "starcoder2" }, @@ -260,6 +243,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OLMOE, "olmoe" }, { LLM_ARCH_OPENELM, "openelm" }, { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, @@ -270,6 +254,10 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_GRANITE, "granite" }, + { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -306,6 +294,12 @@ enum llm_kv { LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_ATTN_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_SWIN_NORM, + LLM_KV_RESCALE_EVERY_N_LAYERS, + LLM_KV_TIME_MIX_EXTRA_DIM, + LLM_KV_TIME_DECAY_EXTRA_DIM, + LLM_KV_RESIDUAL_SCALE, + LLM_KV_EMBEDDING_SCALE, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -320,6 +314,7 @@ enum llm_kv { LLM_KV_ATTENTION_KV_LORA_RANK, LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SCALE, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -341,6 +336,8 @@ enum llm_kv { LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_SSM_DT_B_C_RMS, + LLM_KV_WKV_HEAD_SIZE, + LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_PRE, LLM_KV_TOKENIZER_LIST, @@ -350,6 +347,8 @@ enum llm_kv { LLM_KV_TOKENIZER_MERGES, LLM_KV_TOKENIZER_BOS_ID, LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_EOT_ID, + LLM_KV_TOKENIZER_EOM_ID, LLM_KV_TOKENIZER_UNK_ID, LLM_KV_TOKENIZER_SEP_ID, LLM_KV_TOKENIZER_PAD_ID, @@ -362,14 +361,20 @@ enum llm_kv { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, - LLM_KV_TOKENIZER_PREFIX_ID, - LLM_KV_TOKENIZER_SUFFIX_ID, - LLM_KV_TOKENIZER_MIDDLE_ID, - LLM_KV_TOKENIZER_EOT_ID, - LLM_KV_TOKENIZER_EOM_ID, + LLM_KV_TOKENIZER_FIM_PRE_ID, + LLM_KV_TOKENIZER_FIM_SUF_ID, + LLM_KV_TOKENIZER_FIM_MID_ID, + LLM_KV_TOKENIZER_FIM_PAD_ID, + LLM_KV_TOKENIZER_FIM_REP_ID, + LLM_KV_TOKENIZER_FIM_SEP_ID, LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, + + // deprecated: + LLM_KV_TOKENIZER_PREFIX_ID, + LLM_KV_TOKENIZER_SUFFIX_ID, + LLM_KV_TOKENIZER_MIDDLE_ID, }; static const std::map LLM_KV_NAMES = { @@ -400,11 +405,17 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, - { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, + { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + { LLM_KV_SWIN_NORM, "%s.swin_norm" }, + { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, + { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" }, + { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, + { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, + { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -419,56 +430,67 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, - - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, - - { LLM_KV_SPLIT_NO, "split.no" }, - { LLM_KV_SPLIT_COUNT, "split.count" }, - { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" }, - - { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, - { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, - { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, - { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, - { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, - - { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, - { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, - { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, - { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, - { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, - { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, - { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, - { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, - { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, - { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, - { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, - { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, - { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, - { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, - { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, - { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, - { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, - { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, - { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, - { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, - { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, - { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, - { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, - { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, - - { LLM_KV_ADAPTER_TYPE, "adapter.type" }, - { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + + { LLM_KV_SPLIT_NO, "split.no" }, + { LLM_KV_SPLIT_COUNT, "split.count" }, + { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" }, + + { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, + { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, + { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, + { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, + { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, + { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + + // deprecated + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, + { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, + { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, }; struct LLM_KV { @@ -529,6 +551,29 @@ enum llm_tensor { LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_TIME_MIX_W1, + LLM_TENSOR_TIME_MIX_W2, + LLM_TENSOR_TIME_MIX_LERP_X, + LLM_TENSOR_TIME_MIX_LERP_W, + LLM_TENSOR_TIME_MIX_LERP_K, + LLM_TENSOR_TIME_MIX_LERP_V, + LLM_TENSOR_TIME_MIX_LERP_R, + LLM_TENSOR_TIME_MIX_LERP_G, + LLM_TENSOR_TIME_MIX_FIRST, + LLM_TENSOR_TIME_MIX_DECAY, + LLM_TENSOR_TIME_MIX_DECAY_W1, + LLM_TENSOR_TIME_MIX_DECAY_W2, + LLM_TENSOR_TIME_MIX_KEY, + LLM_TENSOR_TIME_MIX_VALUE, + LLM_TENSOR_TIME_MIX_RECEPTANCE, + LLM_TENSOR_TIME_MIX_GATE, + LLM_TENSOR_TIME_MIX_LN, + LLM_TENSOR_TIME_MIX_OUTPUT, + LLM_TENSOR_CHANNEL_MIX_LERP_K, + LLM_TENSOR_CHANNEL_MIX_LERP_R, + LLM_TENSOR_CHANNEL_MIX_KEY, + LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, + LLM_TENSOR_CHANNEL_MIX_VALUE, LLM_TENSOR_ATTN_Q_A, LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, @@ -565,9 +610,11 @@ enum llm_tensor { LLM_TENSOR_ENC_FFN_DOWN, LLM_TENSOR_ENC_FFN_UP, LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, }; -static const std::map> LLM_TENSOR_NAMES = { +static const std::map> LLM_TENSOR_NAMES = { { LLM_ARCH_LLAMA, { @@ -752,6 +799,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, }, }, { @@ -787,6 +836,7 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, }, }, { @@ -1011,6 +1061,29 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, }, }, + { + LLM_ARCH_MINICPM3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, { LLM_ARCH_GEMMA, { @@ -1145,6 +1218,26 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_OLMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_OPENELM, { @@ -1350,6 +1443,94 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_RWKV6, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" }, + { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" }, + { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, + { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, + { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_R, "blk.%d.channel_mix_lerp_r" }, + { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" }, + { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, + { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, + }, + }, + { + LLM_ARCH_GRANITE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GRANITE_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_CHAMELEON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1377,44 +1558,52 @@ static llm_arch llm_arch_from_string(const std::string & name) { // std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" // std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" // -struct LLM_TN { - LLM_TN(llm_arch arch) : arch(arch) {} - - llm_arch arch; - - std::string operator()(llm_tensor tensor) const { +struct LLM_TN_IMPL { + const llm_arch arch; + const llm_tensor tensor; + const char * const suffix; + const int bid; + const int xid; + + std::string str() const { if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } - return LLM_TENSOR_NAMES.at(arch).at(tensor); - } - std::string operator()(llm_tensor tensor, const std::string & suffix) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { - return "__missing__"; + std::string name = ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid, xid); + + if (suffix != nullptr) { + name += "."; + name += suffix; } - return LLM_TENSOR_NAMES.at(arch).at(tensor) + "." + suffix; + + return name; } - std::string operator()(llm_tensor tensor, int bid) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { - return "__missing__"; - } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid); + operator std::string() const { + return str(); } - std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { - return "__missing__"; - } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid) + "." + suffix; + friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) { + return str == tn.str(); } - std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { - return "__missing__"; - } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid, xid) + "." + suffix; + friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) { + return str != tn.str(); + } +}; + +struct LLM_TN { + LLM_TN(llm_arch arch) : arch(arch) {} + + llm_arch arch; + + LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const { + return { arch, tensor, suffix, bid, xid }; + } + + LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const { + return { arch, tensor, nullptr, bid, xid }; } }; @@ -2087,55 +2276,16 @@ static std::string llama_token_to_piece(const struct llama_model * model, llama_ return piece; } -static lm_ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) { - lm_ggml_backend_buffer_type_t buft = nullptr; - -#if defined(LM_GGML_USE_CUDA) - // host buffers should only be used when data is expected to be copied to/from the GPU - if (host_buffer) { - buft = lm_ggml_backend_cuda_host_buffer_type(); - } -#elif defined(LM_GGML_USE_SYCL) - if (host_buffer) { - buft = lm_ggml_backend_sycl_host_buffer_type(); - } -#elif defined(LM_GGML_USE_CPU_HBM) - buft = lm_ggml_backend_cpu_hbm_buffer_type(); -#elif defined(LM_GGML_USE_VULKAN) - if (host_buffer) { - buft = lm_ggml_backend_vk_host_buffer_type(); - } -#endif - - if (buft == nullptr) { - buft = lm_ggml_backend_cpu_buffer_type(); - } - return buft; - - LM_GGML_UNUSED(host_buffer); -} - // // globals // -struct llama_state { - llama_state() { -#ifdef LM_GGML_USE_METAL - lm_ggml_backend_metal_log_set_callback(log_callback, log_callback_user_data); -#elif defined(LM_GGML_USE_CUDA) - lm_ggml_backend_cuda_log_set_callback(log_callback, log_callback_user_data); -#elif defined(LM_GGML_USE_CANN) - lm_ggml_backend_cann_log_set_callback(log_callback, log_callback_user_data); -#endif - } - - // We save the log callback globally +struct llama_logger_state { lm_ggml_log_callback log_callback = llama_log_callback_default; void * log_callback_user_data = nullptr; }; -static llama_state g_state; +static llama_logger_state g_logger_state; // available llama models enum e_model { @@ -2162,6 +2312,7 @@ enum e_model { MODEL_1B, MODEL_1_3B, MODEL_1_4B, + MODEL_1_6B, MODEL_2B, MODEL_2_8B, MODEL_3B, @@ -2190,6 +2341,7 @@ enum e_model { MODEL_MEDIUM, MODEL_LARGE, MODEL_XL, + MODEL_A1_7B, MODEL_A2_7B, MODEL_8x7B, MODEL_8x22B, @@ -2207,6 +2359,7 @@ struct llama_hparams { bool vocab_only; bool rope_finetuned; bool use_par_res; + bool swin_norm; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on @@ -2239,6 +2392,12 @@ struct llama_hparams { float f_attn_logit_softcapping = 50.0f; float f_final_logit_softcapping = 30.0f; + // for RWKV + uint32_t rescale_every_n_layers = 0; + uint32_t time_mix_extra_dim = 0; + uint32_t time_decay_extra_dim = 0; + uint32_t wkv_head_size = 0; + float rope_attn_factor = 1.0f; float rope_freq_base_train; float rope_freq_scale_train; @@ -2256,13 +2415,18 @@ struct llama_hparams { float f_max_alibi_bias = 0.0f; float f_logit_scale = 0.0f; + // Additional scale factors (Granite/Granite MoE) + float f_residual_scale = 0.0f; + float f_embedding_scale = 0.0f; + float f_attention_scale = 0.0f; + bool causal_attn = true; bool use_alibi = false; bool attn_soft_cap = false; // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 - llama_token dec_start_token_id = -1; + llama_token dec_start_token_id = LLAMA_TOKEN_NULL; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -2302,6 +2466,11 @@ struct llama_hparams { if (this->ssm_dt_rank != other.ssm_dt_rank) return true; if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; + if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true; + if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true; + if (this->time_decay_extra_dim != other.time_decay_extra_dim) return true; + if (this->wkv_head_size != other.wkv_head_size) return true; + if (this->dec_start_token_id != other.dec_start_token_id) return true; const float EPSILON = 1e-9f; @@ -2313,6 +2482,9 @@ struct llama_hparams { if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true; if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true; + if (!is_float_close(this->f_residual_scale, other.f_residual_scale, EPSILON)) return true; + if (!is_float_close(this->f_embedding_scale, other.f_embedding_scale, EPSILON)) return true; + if (!is_float_close(this->f_attention_scale, other.f_attention_scale, EPSILON)) return true; return false; } @@ -2365,15 +2537,25 @@ struct llama_hparams { } uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings - // corresponds to Mamba's conv_states size - // TODO: maybe support other convolution strides than 1 - // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + // corresponds to Mamba's conv_states size or RWKV's token_shift states size + if (wkv_head_size != 0) { + // for RWKV models + return 2 * n_embd; + } else { + // TODO: maybe support other convolution strides than 1 + // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + } } uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings - // corresponds to Mamba's ssm_states size - return ssm_d_state * ssm_d_inner; + if (wkv_head_size != 0) { + // corresponds to RWKV's wkv_states size + return n_embd * wkv_head_size; + } else { + // corresponds to Mamba's ssm_states size + return ssm_d_state * ssm_d_inner; + } } }; @@ -2384,8 +2566,8 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; - uint32_t n_threads; // number of threads to use for generation - uint32_t n_threads_batch; // number of threads to use for batch processing + int n_threads; // number of threads to use for generation + int n_threads_batch; // number of threads to use for batch processing float rope_freq_base; float rope_freq_scale; @@ -2403,6 +2585,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool no_perf; enum llama_pooling_type pooling_type; @@ -2412,6 +2595,11 @@ struct llama_cparams { // TODO: separate into "llama_layer_enc" and "llama_layer_dec" struct llama_layer { + llama_layer() { + // initialize all pointers to NULL + std::memset(this, 0, sizeof(*this)); + } + // normalization struct lm_ggml_tensor * attn_norm; struct lm_ggml_tensor * attn_norm_b; @@ -2492,9 +2680,9 @@ struct llama_layer { struct lm_ggml_tensor * ffn_up_shexp; // ff bias - struct lm_ggml_tensor * ffn_gate_b = nullptr; - struct lm_ggml_tensor * ffn_down_b = nullptr; // b2 - struct lm_ggml_tensor * ffn_up_b = nullptr; // b3 + struct lm_ggml_tensor * ffn_gate_b; + struct lm_ggml_tensor * ffn_down_b; // b2 + struct lm_ggml_tensor * ffn_up_b; // b3 struct lm_ggml_tensor * ffn_act; // mamba proj @@ -2512,6 +2700,36 @@ struct llama_layer { struct lm_ggml_tensor * ssm_conv1d_b; struct lm_ggml_tensor * ssm_dt_b; + // rwkv + struct lm_ggml_tensor * time_mix_w1; + struct lm_ggml_tensor * time_mix_w2; + struct lm_ggml_tensor * time_mix_lerp_x; + struct lm_ggml_tensor * time_mix_lerp_w; + struct lm_ggml_tensor * time_mix_lerp_k; + struct lm_ggml_tensor * time_mix_lerp_v; + struct lm_ggml_tensor * time_mix_lerp_r; + struct lm_ggml_tensor * time_mix_lerp_g; + + struct lm_ggml_tensor * time_mix_first; + struct lm_ggml_tensor * time_mix_decay; + struct lm_ggml_tensor * time_mix_decay_w1; + struct lm_ggml_tensor * time_mix_decay_w2; + struct lm_ggml_tensor * time_mix_key; + struct lm_ggml_tensor * time_mix_value; + struct lm_ggml_tensor * time_mix_receptance; + struct lm_ggml_tensor * time_mix_gate; + + struct lm_ggml_tensor * time_mix_ln; + struct lm_ggml_tensor * time_mix_ln_b; + struct lm_ggml_tensor * time_mix_output; + + struct lm_ggml_tensor * channel_mix_lerp_k; + struct lm_ggml_tensor * channel_mix_lerp_r; + + struct lm_ggml_tensor * channel_mix_key; + struct lm_ggml_tensor * channel_mix_receptance; + struct lm_ggml_tensor * channel_mix_value; + // long rope factors struct lm_ggml_tensor * rope_long = nullptr; struct lm_ggml_tensor * rope_short = nullptr; @@ -2591,31 +2809,22 @@ struct llama_kv_cache { std::vector k_l; // per layer std::vector v_l; - std::vector ctxs; - std::vector bufs; + std::vector ctxs; + std::vector bufs; - size_t total_size() const { + size_t total_size() { size_t size = 0; - for (lm_ggml_backend_buffer_t buf : bufs) { - size += lm_ggml_backend_buffer_get_size(buf); + for (auto & buf : bufs) { + size += lm_ggml_backend_buffer_get_size(buf.get()); } return size; } - - ~llama_kv_cache() { - for (struct lm_ggml_context * ctx : ctxs) { - lm_ggml_free(ctx); - } - for (lm_ggml_backend_buffer_t buf : bufs) { - lm_ggml_backend_buffer_free(buf); - } - } }; struct llama_control_vector { std::vector tensors; // per layer - std::vector ctxs; - std::vector bufs; + std::vector ctxs; + std::vector bufs; int32_t layer_start = -1; int32_t layer_end = -1; @@ -2634,15 +2843,6 @@ struct llama_control_vector { } return cur; } - - ~llama_control_vector() { - for (struct lm_ggml_context * ctx : ctxs) { - lm_ggml_free(ctx); - } - for (lm_ggml_backend_buffer_t buf : bufs) { - lm_ggml_backend_buffer_free(buf); - } - } }; struct llama_model { @@ -2655,48 +2855,57 @@ struct llama_model { llama_hparams hparams = {}; llama_vocab vocab; - struct lm_ggml_tensor * tok_embd; - struct lm_ggml_tensor * type_embd; - struct lm_ggml_tensor * pos_embd; - struct lm_ggml_tensor * tok_norm; - struct lm_ggml_tensor * tok_norm_b; + struct lm_ggml_tensor * tok_embd = nullptr; + struct lm_ggml_tensor * type_embd = nullptr; + struct lm_ggml_tensor * pos_embd = nullptr; + struct lm_ggml_tensor * tok_norm = nullptr; + struct lm_ggml_tensor * tok_norm_b = nullptr; - struct lm_ggml_tensor * output_norm; - struct lm_ggml_tensor * output_norm_b; - struct lm_ggml_tensor * output; - struct lm_ggml_tensor * output_b; - struct lm_ggml_tensor * output_norm_enc; + struct lm_ggml_tensor * output_norm = nullptr; + struct lm_ggml_tensor * output_norm_b = nullptr; + struct lm_ggml_tensor * output = nullptr; + struct lm_ggml_tensor * output_b = nullptr; + struct lm_ggml_tensor * output_norm_enc = nullptr; + + // classifier + struct lm_ggml_tensor * cls = nullptr; + struct lm_ggml_tensor * cls_b = nullptr; + struct lm_ggml_tensor * cls_out = nullptr; + struct lm_ggml_tensor * cls_out_b = nullptr; std::vector layers; + // gguf metadata + std::unordered_map lm_gguf_kv; + llama_split_mode split_mode; int main_gpu; int n_gpu_layers; std::vector rpc_servers; - // gguf metadata - std::unordered_map lm_gguf_kv; + // list of devices used in this model + std::vector devices; - // layer -> buffer type mapping - struct layer_buft { - layer_buft() : buft_matrix(nullptr), buft(nullptr) {} - layer_buft(lm_ggml_backend_buffer_type_t matrix) : buft_matrix(matrix), buft(matrix) {} - layer_buft(lm_ggml_backend_buffer_type_t matrix, lm_ggml_backend_buffer_type_t other) : buft_matrix(matrix), buft(other) {} - lm_ggml_backend_buffer_type_t buft_matrix; // matrices only - used by split buffers and backends that support only matrix multiplication - lm_ggml_backend_buffer_type_t buft; // everything else - }; + // lists of buffer types used for each layer + using buft_list_t = std::vector>; + buft_list_t cpu_buft_list; + std::map gpu_buft_list; - layer_buft buft_input; - layer_buft buft_output; - std::vector buft_layer; + struct layer_dev { + lm_ggml_backend_dev_t dev; + buft_list_t * buft_list; + }; + layer_dev dev_input = {}; + layer_dev dev_output = {}; + std::vector dev_layer; // contexts where the model tensors metadata is stored - std::vector ctxs; + std::vector ctxs; // the model memory buffers for the tensor data - std::vector bufs; + std::vector bufs; // model memory mapped files llama_mmaps mappings; @@ -2715,18 +2924,7 @@ struct llama_model { std::set lora_adapters; ~llama_model() { - for (struct lm_ggml_context * ctx : ctxs) { - lm_ggml_free(ctx); - } - for (lm_ggml_backend_buffer_t buf : bufs) { -#ifdef LM_GGML_USE_CUDA - if (lm_ggml_backend_buffer_get_type(buf) == lm_ggml_backend_cpu_buffer_type()) { - lm_ggml_backend_cuda_unregister_host_buffer(lm_ggml_backend_buffer_get_base(buf)); - } -#endif - lm_ggml_backend_buffer_free(buf); - } - while (!lora_adapters.empty()) { + while (!lora_adapters.empty()) { llama_lora_adapter_free(*lora_adapters.begin()); } } @@ -2737,9 +2935,6 @@ struct llama_sbatch_seq { llama_seq_id * seq_id; size_t offset; size_t length; - - // helper for smoother batch API transition -- can be deprecated in the future - llama_seq_id all_seq_id; // used if seq_id == NULL }; // sequence-length-aware batch splitting @@ -2834,50 +3029,30 @@ struct llama_sbatch { } else { ubatch.embd = nullptr; } - // from here on, the else branches are deprecated; - // they are helpers for smoother batch API transition - if (batch->pos) { - if (ubatch.equal_seqs) { - for (size_t i = 0; i < length; ++i) { - ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; - } - } else { - // simple split - ubatch.pos = batch->pos + seq.offset; - } - } else { + if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { - llama_pos bi = ids[seq.offset + i]; - ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; } + } else { + // simple split + ubatch.pos = batch->pos + seq.offset; } if (ubatch.equal_seqs) { ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; if (seq.seq_id) { ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; - } else { - LM_GGML_ASSERT(seq.n_seq_id == 1); - ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; } } else { // simple split if (batch->n_seq_id) { - for (size_t i = 0; i < length; ++i) { - ubatch.n_seq_id = batch->n_seq_id + seq.offset; - } + ubatch.n_seq_id = batch->n_seq_id + seq.offset; } else { for (size_t i = 0; i < length; ++i) { ubatch.n_seq_id[ubatch.n_seqs + i] = 1; } } if (batch->seq_id) { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id = batch->seq_id + seq.offset; - } - } else { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id; - } + ubatch.seq_id = batch->seq_id + seq.offset; } } if (logits_all) { @@ -2996,7 +3171,6 @@ struct llama_sbatch { s.seq_id = nullptr; s.offset = 0; s.length = n_tokens; - s.all_seq_id = batch.all_seq_id; return; } std::sort(ids.begin(), ids.end(), @@ -3019,7 +3193,7 @@ struct llama_sbatch { if (batch.pos) { return batch.pos[a] < batch.pos[b]; } - // no pos, sort by id (assuming batch.all_pos_1 is positive) + // no pos, sort by id return a < b; } // shared prompts go first @@ -3029,30 +3203,25 @@ struct llama_sbatch { // init seq llama_sbatch_seq * last_seq = nullptr; - if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) { - for (size_t i = 0; i < n_tokens; ++i) { - const size_t bi = ids[i]; - const int32_t n_seqs = batch.n_seq_id[bi]; - llama_seq_id * seq_ids = batch.seq_id[bi]; - if (last_seq != nullptr) { - bool same = n_seqs == last_seq->n_seq_id; - for (int32_t j = 0; same && j < n_seqs; ++j) { - if (seq_ids[j] != last_seq->seq_id[j]) { - same = false; - } - } - if (same) { - last_seq->length += 1; - continue; + for (size_t i = 0; i < n_tokens; ++i) { + const size_t bi = ids[i]; + const int32_t n_seqs = batch.n_seq_id[bi]; + llama_seq_id * seq_ids = batch.seq_id[bi]; + if (last_seq != nullptr) { + bool same = n_seqs == last_seq->n_seq_id; + for (int32_t j = 0; same && j < n_seqs; ++j) { + if (seq_ids[j] != last_seq->seq_id[j]) { + same = false; } } - llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id}; - seq.push_back(new_seq); - last_seq = &seq.back(); + if (same) { + last_seq->length += 1; + continue; + } } - } else { - llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id}; + llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1}; seq.push_back(new_seq); + last_seq = &seq.back(); } // keep shared prompts first at the end, then sort by length descending. std::sort(seq.begin(), seq.end(), @@ -3069,54 +3238,41 @@ struct llama_sbatch { struct llama_context { llama_context(const llama_model & model) : model(model) - , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {} - ~llama_context() { - lm_ggml_backend_sched_free(sched); - - for (lm_ggml_backend_t backend : backends) { - lm_ggml_backend_free(backend); - } - - lm_ggml_backend_buffer_free(buf_output); - } - const struct llama_model & model; struct llama_cparams cparams; - struct llama_sampling sampling; struct llama_sbatch sbatch; struct llama_kv_cache kv_self; struct llama_control_vector cvec; std::unordered_map lora_adapters; - std::vector backends; -#ifdef LM_GGML_USE_METAL - lm_ggml_backend_t backend_metal = nullptr; -#endif -#ifdef LM_GGML_USE_BLAS - lm_ggml_backend_t backend_blas = nullptr; -#endif + std::vector backends; + std::vector> set_n_threads_fns; + lm_ggml_backend_t backend_cpu = nullptr; + lm_ggml_threadpool_t threadpool = nullptr; + lm_ggml_threadpool_t threadpool_batch = nullptr; + bool has_evaluated_once = false; - int64_t t_start_us; - int64_t t_load_us; - int64_t t_p_eval_us = 0; - int64_t t_eval_us = 0; + mutable int64_t t_start_us; + mutable int64_t t_load_us; + mutable int64_t t_p_eval_us = 0; + mutable int64_t t_eval_us = 0; - int64_t t_compute_start_us = 0; - int64_t n_queued_tokens = 0; + mutable int64_t t_compute_start_us = 0; + mutable int64_t n_queued_tokens = 0; - int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - int32_t n_eval = 0; // number of eval calls + mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + mutable int32_t n_eval = 0; // number of eval calls // host buffer for the model output (logits and embeddings) - lm_ggml_backend_buffer_t buf_output = nullptr; + lm_ggml_backend_buffer_ptr buf_output; // decode output (2-dimensional array: [n_outputs][n_vocab]) size_t logits_size = 0; // capacity (of floats) for logits @@ -3146,7 +3302,7 @@ struct llama_context { // memory buffers used to evaluate the model std::vector buf_compute_meta; - lm_ggml_backend_sched_t sched = nullptr; + lm_ggml_backend_sched_ptr sched; lm_ggml_abort_callback abort_callback = nullptr; void * abort_callback_data = nullptr; @@ -3180,8 +3336,8 @@ struct llama_lora_adapter { struct llama_model * base_model; // map tensor name to lora_a_b std::unordered_map ab_map; - std::vector ctxs; - std::vector bufs; + std::vector ctxs; + std::vector bufs; float alpha; @@ -3199,12 +3355,6 @@ struct llama_lora_adapter { } ~llama_lora_adapter() { - for (struct lm_ggml_context * ctx : ctxs) { - lm_ggml_free(ctx); - } - for (lm_ggml_backend_buffer_t buf : bufs) { - lm_ggml_backend_buffer_free(buf); - } auto pos = base_model->lora_adapters.find(this); if (pos != base_model->lora_adapters.end()) { base_model->lora_adapters.erase(pos); @@ -3212,120 +3362,45 @@ struct llama_lora_adapter { } }; -static size_t llama_get_device_count(const llama_model & model) { - size_t count = 1; -#if defined(LM_GGML_USE_CUDA) - count = lm_ggml_backend_cuda_get_device_count(); -#elif defined(LM_GGML_USE_SYCL) - count = lm_ggml_backend_sycl_get_device_count(); -#elif defined(LM_GGML_USE_VULKAN) - count = lm_ggml_backend_vk_get_device_count(); -#elif defined(LM_GGML_USE_CANN) - return lm_ggml_backend_cann_get_device_count(); -#endif -#if defined(LM_GGML_USE_RPC) - count += model.rpc_servers.size(); -#endif - return count; - LM_GGML_UNUSED(model); -} - -static lm_ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) { - lm_ggml_backend_buffer_type_t buft = nullptr; - -#if defined(LM_GGML_USE_RPC) - int dev_count = (int)llama_get_device_count(model); - int rpc_count = (int)model.rpc_servers.size(); - if (gpu >= dev_count - rpc_count) { - const char * endpoint = model.rpc_servers[gpu - dev_count + rpc_count].c_str(); - return lm_ggml_backend_rpc_buffer_type(endpoint); - } -#endif -#if defined(LM_GGML_USE_METAL) - buft = lm_ggml_backend_metal_buffer_type(); -#elif defined(LM_GGML_USE_CUDA) - buft = lm_ggml_backend_cuda_buffer_type(gpu); -#elif defined(LM_GGML_USE_VULKAN) - buft = lm_ggml_backend_vk_buffer_type(gpu); -#elif defined(LM_GGML_USE_SYCL) - buft = lm_ggml_backend_sycl_buffer_type(gpu); -#elif defined(LM_GGML_USE_KOMPUTE) - buft = lm_ggml_backend_kompute_buffer_type(gpu); - if (buft == nullptr) { - LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu); - } -#elif defined(LM_GGML_USE_CANN) - buft = lm_ggml_backend_cann_buffer_type(gpu); -#endif - - if (buft == nullptr) { - buft = llama_default_buffer_type_cpu(true); - } - return buft; - LM_GGML_UNUSED(model); - LM_GGML_UNUSED(gpu); +static int llama_get_device_count(const llama_model & model) { + return (int) model.devices.size(); } -static lm_ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) { - lm_ggml_backend_buffer_type_t buft = nullptr; - -#ifdef LM_GGML_USE_CUDA - if (lm_ggml_backend_cuda_get_device_count() > 1) { - buft = lm_ggml_backend_cuda_split_buffer_type(tensor_split); - } -#endif - -#ifdef LM_GGML_USE_SYCL - if (lm_ggml_backend_sycl_get_device_count() > 1) { - buft = lm_ggml_backend_sycl_split_buffer_type(tensor_split); +template +static bool buft_supported(lm_ggml_backend_buffer_type_t buft, lm_ggml_backend_dev_t dev, F & fn) { + lm_ggml_init_params params = { + /*.mem_size =*/ lm_ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + lm_ggml_context_ptr ctx { lm_ggml_init(params) }; + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); } -#endif - if (buft == nullptr) { - buft = llama_default_buffer_type_offload(model, fallback_gpu); + lm_ggml_backend_buffer_ptr buf { lm_ggml_backend_buft_alloc_buffer(buft, 0) }; + lm_ggml_tensor * op_tensor = fn(ctx.get()); + for (int i = 0; i < LM_GGML_MAX_SRC; i++) { + if (op_tensor->src[i] != nullptr) { + assert(op_tensor->src[i]->buffer == nullptr); + op_tensor->src[i]->buffer = buf.get(); + } } - return buft; + bool op_supported = lm_ggml_backend_dev_supports_op(dev, op_tensor); - LM_GGML_UNUSED(tensor_split); + return op_supported; } -static size_t llama_get_device_memory(const llama_model & model, int device) { -#if defined(LM_GGML_USE_RPC) - int dev_count = (int)llama_get_device_count(model); - int rpc_count = (int)model.rpc_servers.size(); - if (device >= dev_count - rpc_count) { - size_t total; - size_t free; - const char * endpoint = model.rpc_servers[device - dev_count + rpc_count].c_str(); - lm_ggml_backend_rpc_get_device_memory(endpoint, &free, &total); - return free; +template +static lm_ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) { + for (const auto & cur : buft_list) { + lm_ggml_backend_dev_t cur_dev = cur.first; + lm_ggml_backend_buffer_type_t cur_buft = cur.second; + if (buft_supported(cur_buft, cur_dev, fn)) { + return cur_buft; + } } -#endif -#if defined(LM_GGML_USE_CUDA) - size_t total; - size_t free; - lm_ggml_backend_cuda_get_device_memory(device, &free, &total); - return free; -#elif defined(LM_GGML_USE_SYCL) - size_t total; - size_t free; - lm_ggml_backend_sycl_get_device_memory(device, &free, &total); - return free; -#elif defined(LM_GGML_USE_VULKAN) - size_t total; - size_t free; - lm_ggml_backend_vk_get_device_memory(device, &free, &total); - return free; -#elif defined(LM_GGML_USE_CANN) - size_t total; - size_t free; - lm_ggml_backend_cann_get_device_memory(device, &free, &total); - return free; -#else - return 1; -#endif - LM_GGML_UNUSED(model); - LM_GGML_UNUSED(device); + throw std::runtime_error(format("no suitable buffer type found")); } // @@ -3361,33 +3436,26 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(kv_size); - // count used buffer types - std::map buft_layer_count; - if (offload) { - for (int64_t i = 0; i < n_layer; ++i) { - buft_layer_count[model.buft_layer[i].buft]++; - } - } else { - buft_layer_count[llama_default_buffer_type_cpu(true)] = n_layer; - } - // create a context for each buffer type std::map ctx_map; - for (auto & it : buft_layer_count) { - int n_layers = it.second; - struct lm_ggml_init_params params = { - /*.mem_size =*/ 2u*n_layers*lm_ggml_tensor_overhead(), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - lm_ggml_context * ctx = lm_ggml_init(params); - if (!ctx) { - LLAMA_LOG_ERROR("%s: failed to allocate context for kv cache\n", __func__); - return false; + auto ctx_for_buft = [&](lm_ggml_backend_buffer_type_t buft) -> lm_ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + struct lm_ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*lm_ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + lm_ggml_context * ctx = lm_ggml_init(params); + if (!ctx) { + return nullptr; + } + ctx_map[buft] = ctx; + cache.ctxs.emplace_back(ctx); + return ctx; } - ctx_map[it.first] = ctx; - cache.ctxs.push_back(ctx); - } + return it->second; + }; cache.k_l.reserve(n_layer); cache.v_l.reserve(n_layer); @@ -3396,7 +3464,28 @@ static bool llama_kv_cache_init( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - struct lm_ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); + const llama_model::buft_list_t * buft_list; + if (offload) { + buft_list = model.dev_layer.at(i).buft_list; + } else { + buft_list = &model.cpu_buft_list; + } + lm_ggml_backend_buffer_type_t buft = select_buft(*buft_list, + [&](lm_ggml_context * ctx) { + lm_ggml_tensor * k = lm_ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) { + return k; + } + lm_ggml_tensor * p = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, 1); + return lm_ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type); + }); + lm_ggml_context * ctx = ctx_for_buft(buft); + + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); + return false; + } + lm_ggml_tensor * k = lm_ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); lm_ggml_tensor * v = lm_ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); lm_ggml_format_name(k, "cache_k_l%d", i); @@ -3407,8 +3496,9 @@ static bool llama_kv_cache_init( // allocate tensors and initialize the buffers to avoid NaNs in the padding for (auto it : ctx_map) { - lm_ggml_backend_buffer_type_t buft = it.first; - lm_ggml_context * ctx = it.second; + auto * buft = it.first; + auto * ctx = it.second; + lm_ggml_backend_buffer_t buf = lm_ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); @@ -3416,7 +3506,7 @@ static bool llama_kv_cache_init( } lm_ggml_backend_buffer_clear(buf, 0); LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, lm_ggml_backend_buffer_name(buf), lm_ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - cache.bufs.push_back(buf); + cache.bufs.emplace_back(buf); } return true; @@ -3434,7 +3524,7 @@ static bool llama_kv_cache_find_slot( const uint32_t n_seq_tokens = batch.n_seq_tokens; if (cache.recurrent) { - // For recurrent state architectures (like Mamba), + // For recurrent state architectures (like Mamba or RWKV), // each cache cell can store the state for a whole sequence. // A slot should be always be contiguous. @@ -3669,7 +3759,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) { cache.used = 0; for (auto & buf : cache.bufs) { - lm_ggml_backend_buffer_clear(buf, 0); + lm_ggml_backend_buffer_clear(buf.get(), 0); } } @@ -3683,7 +3773,7 @@ static bool llama_kv_cache_seq_rm( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - // models like Mamba can't have a state partially erased + // models like Mamba or RWKV can't have a state partially erased if (cache.recurrent) { if (seq_id >= (int64_t) cache.size) { // could be fatal @@ -3697,7 +3787,8 @@ static bool llama_kv_cache_seq_rm( if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { return false; } - if (p0 <= cell.pos && p1 < cell.pos) { + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { tail_id = -1; } } @@ -3819,7 +3910,7 @@ static void llama_kv_cache_seq_add( if (p0 == p1) return; if (cache.recurrent) { - // for Mamba-like models, only the pos needs to be shifted + // for Mamba-like or RWKV models, only the pos needs to be shifted if (0 <= seq_id && seq_id < (int64_t) cache.size) { const int32_t tail_id = cache.cells[seq_id].tail; if (tail_id >= 0) { @@ -3868,7 +3959,7 @@ static void llama_kv_cache_seq_div( if (p0 == p1) return; if (cache.recurrent) { - // for Mamba-like models, only the pos needs to be changed + // for Mamba-like or RWKV models, only the pos needs to be changed if (0 <= seq_id && seq_id < (int64_t) cache.size) { const int32_t tail_id = cache.cells[seq_id].tail; if (tail_id >= 0) { @@ -4151,21 +4242,38 @@ struct llama_model_loader { lm_ggml_tensor * tensor; - llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct lm_gguf_context * lm_gguf_ctx, lm_ggml_tensor * tensor) : idx(idx), tensor(tensor) { - const int tensor_idx = lm_gguf_find_tensor(lm_gguf_ctx, name); - offs = lm_gguf_get_data_offset(lm_gguf_ctx) + lm_gguf_get_tensor_offset(lm_gguf_ctx, tensor_idx); + llama_tensor_weight(const llama_file * file, uint16_t idx, const struct lm_gguf_context * lm_gguf_ctx, lm_ggml_tensor * tensor) : idx(idx), tensor(tensor) { + const int tensor_idx = lm_gguf_find_tensor(lm_gguf_ctx, lm_ggml_get_name(tensor)); + if (tensor_idx < 0) { + throw std::runtime_error(format("tensor '%s' not found in the model", lm_ggml_get_name(tensor))); + } + offs = lm_gguf_get_data_offset(lm_gguf_ctx) + lm_gguf_get_tensor_offset(lm_gguf_ctx, tensor_idx); if (offs + lm_ggml_nbytes(tensor) < offs || offs + lm_ggml_nbytes(tensor) > file->size) { - throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name)); + throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", lm_ggml_get_name(tensor))); + } + } + }; + + // custom comparator to sort weights more nicely by layer + struct weight_name_comparer { + bool operator()(const std::string & a, const std::string & b) const { + int a_layer = -1; + int b_layer = -1; + sscanf(a.c_str(), "blk.%d.", &a_layer); + sscanf(b.c_str(), "blk.%d.", &b_layer); + if (a_layer != b_layer) { + return a_layer < b_layer; } + return a < b; } }; - std::vector weights; + std::map weights_map; std::unordered_map kv_overrides; - struct lm_gguf_context * meta = NULL; - std::vector contexts; + lm_gguf_context_ptr meta; + std::vector contexts; std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); @@ -4188,7 +4296,7 @@ struct llama_model_loader { /*.ctx = */ &ctx, }; - meta = lm_gguf_init_from_file(fname.c_str(), params); + meta.reset(lm_gguf_init_from_file(fname.c_str(), params)); if (!meta) { throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); } @@ -4203,7 +4311,14 @@ struct llama_model_loader { // For subsidiary files, `meta` tensor data offset must not be used, // so we build a unified tensors index for weights. for (lm_ggml_tensor * cur = lm_ggml_get_first_tensor(ctx); cur; cur = lm_ggml_get_next_tensor(ctx, cur)) { - weights.emplace_back(files.back().get(), 0, cur->name, meta, cur); + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", lm_ggml_get_name(cur))); + } + n_elements += lm_ggml_nelements(cur); + n_bytes += lm_ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur)); } uint16_t n_split = 0; get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); @@ -4233,7 +4348,7 @@ struct llama_model_loader { /*.no_alloc = */ true, /*.ctx = */ &ctx, }; - struct lm_gguf_context * ctx_gguf = lm_gguf_init_from_file(split_path, split_params); + lm_gguf_context_ptr ctx_gguf { lm_gguf_init_from_file(split_path, split_params) }; if (!ctx_gguf) { throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path)); } @@ -4243,17 +4358,22 @@ struct llama_model_loader { // Save tensors data offset info of the shard. for (lm_ggml_tensor * cur = lm_ggml_get_first_tensor(ctx); cur; cur = lm_ggml_get_next_tensor(ctx, cur)) { - weights.emplace_back(files.back().get(), idx, cur->name, ctx_gguf, cur); + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", lm_ggml_get_name(cur))); + } + n_elements += lm_ggml_nelements(cur); + n_bytes += lm_ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur)); } - - lm_gguf_free(ctx_gguf); } get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); // sanity check { - const int n_tensors_loaded = (int) weights.size(); + const int n_tensors_loaded = (int) weights_map.size(); if (n_tensors != n_tensors_loaded) { throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); } @@ -4262,23 +4382,10 @@ struct llama_model_loader { LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); } - n_kv = lm_gguf_get_n_kv(meta); - n_tensors = weights.size(); - - fver = (enum llama_fver) lm_gguf_get_version(meta); + n_kv = lm_gguf_get_n_kv(meta.get()); + n_tensors = weights_map.size(); - std::set tensor_names; - for (auto & w : weights) { - n_elements += lm_ggml_nelements(w.tensor); - n_bytes += lm_ggml_nbytes(w.tensor); - // make sure there is no duplicated tensor names - const std::string name(w.tensor->name); - auto found = tensor_names.find(name); - if (found != tensor_names.end()) { - throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", w.tensor->name)); - } - tensor_names.insert(name); - } + fver = (enum llama_fver) lm_gguf_get_version(meta.get()); LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); @@ -4291,8 +4398,10 @@ struct llama_model_loader { uint32_t n_type_max = 0; enum lm_ggml_type type_max = LM_GGML_TYPE_F32; - for (int i = 0; i < n_tensors; i++) { - const lm_ggml_tensor * tensor = weights.at(i).tensor; + for (const auto & it : weights_map) { + const llama_tensor_weight & w = it.second; + const lm_ggml_tensor * tensor = w.tensor; + enum lm_ggml_type type = tensor->type; n_type[type]++; @@ -4303,8 +4412,8 @@ struct llama_model_loader { } if (trace > 0) { - const uint16_t sid = weights.at(i).idx; - LLAMA_LOG_INFO("%s: - tensor %4d, split %2d: %32s %-8s [ %s ]\n", __func__, i, sid, lm_ggml_get_name(tensor), lm_ggml_type_name(type), llama_format_tensor_shape(tensor).c_str()); + const uint16_t sid = w.idx; + LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, lm_ggml_get_name(tensor), lm_ggml_type_name(type), llama_format_tensor_shape(tensor).c_str()); } } @@ -4322,6 +4431,8 @@ struct llama_model_loader { case LM_GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; case LM_GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; case LM_GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + case LM_GGML_TYPE_TQ1_0: ftype = LLAMA_FTYPE_MOSTLY_TQ1_0; break; + case LM_GGML_TYPE_TQ2_0: ftype = LLAMA_FTYPE_MOSTLY_TQ2_0; break; case LM_GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; case LM_GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; case LM_GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break; @@ -4345,23 +4456,23 @@ struct llama_model_loader { ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED); { - const int kid = lm_gguf_find_key(meta, "general.file_type"); // TODO: use LLM_KV + const int kid = lm_gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV if (kid >= 0) { - ftype = (llama_ftype) lm_gguf_get_val_u32(meta, kid); + ftype = (llama_ftype) lm_gguf_get_val_u32(meta.get(), kid); } } LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); for (int i = 0; i < n_kv; i++) { - const char * name = lm_gguf_get_key(meta, i); - const enum lm_gguf_type type = lm_gguf_get_kv_type(meta, i); + const char * name = lm_gguf_get_key(meta.get(), i); + const enum lm_gguf_type type = lm_gguf_get_kv_type(meta.get(), i); const std::string type_name = type == LM_GGUF_TYPE_ARRAY - ? format("%s[%s,%d]", lm_gguf_type_name(type), lm_gguf_type_name(lm_gguf_get_arr_type(meta, i)), lm_gguf_get_arr_n(meta, i)) + ? format("%s[%s,%d]", lm_gguf_type_name(type), lm_gguf_type_name(lm_gguf_get_arr_type(meta.get(), i)), lm_gguf_get_arr_n(meta.get(), i)) : lm_gguf_type_name(type); - std::string value = lm_gguf_kv_to_str(meta, i); + std::string value = lm_gguf_kv_to_str(meta.get(), i); const size_t MAX_VALUE_LEN = 40; if (value.size() > MAX_VALUE_LEN) { value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); @@ -4390,19 +4501,10 @@ struct llama_model_loader { this->check_tensors = check_tensors; } - ~llama_model_loader() { - if (meta) { - lm_gguf_free(meta); - } - for (auto * ctx : contexts) { - lm_ggml_free(ctx); - } - } - template typename std::enable_if::value, bool>::type get_arr_n(const std::string & key, T & result, const bool required = true) { - const int kid = lm_gguf_find_key(meta, key.c_str()); + const int kid = lm_gguf_find_key(meta.get(), key.c_str()); if (kid < 0) { if (required) { @@ -4412,7 +4514,7 @@ struct llama_model_loader { } struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta, kid); + GGUFMeta::GKV::get_kv(meta.get(), kid); result = arr_info.length; @@ -4427,9 +4529,9 @@ struct llama_model_loader { template bool get_arr(const std::string & key, std::vector & result, const bool required = true) { - const int kid = lm_gguf_find_key(meta, key.c_str()); + const int kid = lm_gguf_find_key(meta.get(), key.c_str()); - if (kid < 0 || lm_gguf_get_kv_type(meta, kid) != LM_GGUF_TYPE_ARRAY) { + if (kid < 0 || lm_gguf_get_kv_type(meta.get(), kid) != LM_GGUF_TYPE_ARRAY) { if (required) { throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } @@ -4437,7 +4539,7 @@ struct llama_model_loader { } struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta, kid); + GGUFMeta::GKV::get_kv(meta.get(), kid); switch (arr_info.gt) { case LM_GGUF_TYPE_FLOAT32: LM_GGML_ASSERT((std::is_same::value)); break; @@ -4456,9 +4558,9 @@ struct llama_model_loader { template bool get_arr(const std::string & key, std::array & result, const bool required = true) { - const int kid = lm_gguf_find_key(meta, key.c_str()); + const int kid = lm_gguf_find_key(meta.get(), key.c_str()); - if (kid < 0 || lm_gguf_get_kv_type(meta, kid) != LM_GGUF_TYPE_ARRAY) { + if (kid < 0 || lm_gguf_get_kv_type(meta.get(), kid) != LM_GGUF_TYPE_ARRAY) { if (required) { throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } @@ -4466,7 +4568,7 @@ struct llama_model_loader { } struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta, kid); + GGUFMeta::GKV::get_kv(meta.get(), kid); switch (arr_info.gt) { case LM_GGUF_TYPE_FLOAT32: LM_GGML_ASSERT((std::is_same::value)); break; @@ -4498,7 +4600,7 @@ struct llama_model_loader { const struct llama_model_kv_override * override = it != kv_overrides.end() ? &it->second : nullptr; - const bool found = GGUFMeta::GKV::set(meta, key, result, override); + const bool found = GGUFMeta::GKV::set(meta.get(), key, result, override); if (required && !found) { throw std::runtime_error(format("key not found in model: %s", key.c_str())); @@ -4515,7 +4617,7 @@ struct llama_model_loader { // get array of n <= N_MAX elements, or a single element repeated n times template bool get_key_or_arr(const std::string & key, std::array & result, uint32_t n, const bool required = true) { - const int kid = lm_gguf_find_key(meta, key.c_str()); + const int kid = lm_gguf_find_key(meta.get(), key.c_str()); if (kid < 0) { if (required) { @@ -4528,9 +4630,9 @@ struct llama_model_loader { throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); } - if (lm_gguf_get_kv_type(meta, kid) == LM_GGUF_TYPE_ARRAY) { + if (lm_gguf_get_kv_type(meta.get(), kid) == LM_GGUF_TYPE_ARRAY) { struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta, kid); + GGUFMeta::GKV::get_kv(meta.get(), kid); if (n != arr_info.length) { throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length)); @@ -4566,21 +4668,13 @@ struct llama_model_loader { return llm_kv.arch; } - const char * get_tensor_name(int i) const { - return weights.at(i).tensor->name; - } - const llama_tensor_weight * get_weight(const char * name) const { - for (const auto & weight : weights) { - if (strcmp(name, weight.tensor->name) == 0) { - return &weight; - } + auto pos = weights_map.find(name); + if (pos != weights_map.end()) { + return &pos->second; } - return nullptr; - } - const llama_tensor_weight * get_weight(int i) const { - return get_weight(get_tensor_name(i)); + return nullptr; } const llama_tensor_weight & require_weight(const char * name) const { @@ -4599,28 +4693,11 @@ struct llama_model_loader { return weight->tensor; } - struct lm_ggml_tensor * require_tensor_meta(const char * name) const { - struct lm_ggml_tensor * tensor = get_tensor_meta(name); + struct lm_ggml_tensor * require_tensor_meta(const std::string & name) const { + struct lm_ggml_tensor * tensor = get_tensor_meta(name.c_str()); if (!tensor) { - throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name)); - } - return tensor; - } - - struct lm_ggml_tensor * get_tensor_meta(int i) const { - return get_tensor_meta(get_tensor_name(i)); - } - - struct lm_ggml_tensor * create_tensor_for(struct lm_ggml_context * ctx, const struct lm_ggml_tensor * cur, bool duplicated) { - struct lm_ggml_tensor * tensor = lm_ggml_dup_tensor(ctx, cur); - lm_ggml_set_name(tensor, lm_ggml_get_name(cur)); - - if (duplicated) { - size_data += lm_ggml_nbytes(cur); - } else { - n_created++; + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); } - return tensor; } @@ -4657,17 +4734,29 @@ struct llama_model_loader { static const int TENSOR_NOT_REQUIRED = 1; static const int TENSOR_DUPLICATED = 2; - struct lm_ggml_tensor * create_tensor(struct lm_ggml_context * ctx, const std::string & name, const std::vector & ne, int flags = 0) { + struct lm_ggml_tensor * create_tensor(struct lm_ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags = 0) { const struct lm_ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { return NULL; } - return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED); + bool duplicated = flags & TENSOR_DUPLICATED; + + struct lm_ggml_tensor * tensor = lm_ggml_dup_tensor(ctx, cur); + lm_ggml_set_name(tensor, lm_ggml_get_name(cur)); + + if (duplicated) { + size_data += lm_ggml_nbytes(cur); + } else { + n_created++; + } + + return tensor; + } - struct lm_ggml_tensor * create_tensor_as_view(struct lm_ggml_context * ctx, struct lm_ggml_tensor * base, const std::string & name, const std::vector & ne, size_t offset, bool required = true) { + struct lm_ggml_tensor * create_tensor_as_view(struct lm_ggml_context * ctx, struct lm_ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true) { const struct lm_ggml_tensor * cur = check_tensor_dims(name, ne, required); if (cur == NULL) { @@ -4680,7 +4769,7 @@ struct llama_model_loader { std::array dims; for (size_t i = 0; i < LM_GGML_MAX_DIMS; ++i) { - dims[i] = i < ne.size() ? ne[i] : 1; + dims[i] = i < ne.size() ? ne.begin()[i] : 1; } struct lm_ggml_tensor * tensor = lm_ggml_view_4d(ctx, base, @@ -4718,8 +4807,8 @@ struct llama_model_loader { } // compute the total size of all tensors for progress reporting - for (auto & w : weights) { - size_data += lm_ggml_nbytes(w.tensor); + for (const auto & it : weights_map) { + size_data += lm_ggml_nbytes(it.second.tensor); } } @@ -4731,19 +4820,12 @@ struct llama_model_loader { *last = 0; *addr = mapping->addr; for (lm_ggml_tensor * tensor = lm_ggml_get_first_tensor(ctx); tensor; tensor = lm_ggml_get_next_tensor(ctx, tensor)) { - try { - const auto * weight = get_weight(lm_ggml_get_name(tensor)); - if (!weight) { - continue; - } - if (weight->idx != idx) { - continue; - } - *first = std::min(*first, weight->offs); - *last = std::max(*last, weight->offs + lm_ggml_nbytes(tensor)); - } catch(...) { - // the tensor is not in the model + const auto * weight = get_weight(lm_ggml_get_name(tensor)); + if (!weight || weight->idx != idx) { + continue; } + *first = std::min(*first, weight->offs); + *last = std::max(*last, weight->offs + lm_ggml_nbytes(tensor)); } } @@ -4778,7 +4860,7 @@ struct llama_model_loader { // Returns false if cancelled by progress_callback bool load_all_data( struct lm_ggml_context * ctx, - llama_buf_map & bufs_mmap, + llama_buf_map & bufs, llama_mlocks * lmlocks, llama_progress_callback progress_callback, void * progress_callback_user_data) { @@ -4787,43 +4869,94 @@ struct llama_model_loader { std::vector> read_buf; std::vector>> validation_result; -#if defined(LM_GGML_USE_CUDA) // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. // NVMe raid configurations might require more / larger buffers. constexpr size_t n_buffers = 4; constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB std::vector host_buffers; - std::vector host_ptrs; std::vector events; + std::vector host_ptrs; size_t buffer_idx = 0; // buffer to use for async loads - - lm_ggml_backend_t cuda_backend = nullptr; - if (!use_mmap && !check_tensors) { + lm_ggml_backend_t upload_backend = [&](const char * func) -> lm_ggml_backend_t { + if (use_mmap || check_tensors) { + return nullptr; + } // When not using mmaped io use async uploads from pinned memory to GPU memory. - // First determine if the CUDA backend is active, and if so, determine the device ID. - lm_ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr; - if (buf) { - lm_ggml_backend_buffer_type_t buffer_type = lm_ggml_backend_buffer_get_type(buf); - for (int i = 0; i < lm_ggml_backend_cuda_get_device_count(); ++i) { - auto * cuda_buffer_type = lm_ggml_backend_cuda_buffer_type(i); - if (buffer_type == cuda_buffer_type) { - cuda_backend = lm_ggml_backend_cuda_init(i); - break; - } - } + // First determine if the backend supports the necessary features for async uploads. + auto * buf = bufs.count(0) ? bufs.at(0) : nullptr; + if (!buf) { + LLAMA_LOG_DEBUG("%s: no buffer found for async uploads\n", func); + return nullptr; + } + + auto * buft = lm_ggml_backend_buffer_get_type(buf); + auto * dev = lm_ggml_backend_buft_get_device(buft); + if (!dev) { + LLAMA_LOG_DEBUG("%s: no device found for buffer type %s for async uploads\n", func, + lm_ggml_backend_buft_name(buft)); + return nullptr; + } + + if (buft != lm_ggml_backend_dev_buffer_type(dev)) { + LLAMA_LOG_DEBUG("%s: buffer type %s is not the default buffer type for device %s for async uploads\n", func, + lm_ggml_backend_buft_name(buft), lm_ggml_backend_dev_name(dev)); + return nullptr; } - // If the cuda backend is active create pinned memory buffers and events for synchronisation. - if (cuda_backend) { - for (size_t idx = 0; idx < n_buffers; ++idx) { - host_buffers.emplace_back(lm_ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size)); - host_ptrs.emplace_back(lm_ggml_backend_buffer_get_base(host_buffers[idx])); - events.emplace_back(lm_ggml_backend_event_new(cuda_backend)); + lm_ggml_backend_dev_props props; + lm_ggml_backend_dev_get_props(dev, &props); + if (!props.caps.async || !props.caps.host_buffer || !props.caps.events) { + LLAMA_LOG_DEBUG("%s: device %s does not support async, host buffers or events\n", func, + lm_ggml_backend_dev_name(dev)); + return nullptr; + } + + auto * host_buft = lm_ggml_backend_dev_host_buffer_type(dev); + if (!host_buft) { + LLAMA_LOG_DEBUG("%s: no host buffer type found for device %s\n", func, + lm_ggml_backend_dev_name(dev)); + return nullptr; + } + + // If the backend is supported, create pinned memory buffers and events for synchronisation. + for (size_t idx = 0; idx < n_buffers; ++idx) { + auto * buf = lm_ggml_backend_buft_alloc_buffer(host_buft, buffer_size); + if (!buf) { + LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func, + lm_ggml_backend_dev_name(dev)); + return nullptr; + } + + host_buffers.emplace_back(buf); + host_ptrs.emplace_back(lm_ggml_backend_buffer_get_base(buf)); + + auto * event = lm_ggml_backend_event_new(dev); + if (!event) { + LLAMA_LOG_DEBUG("%s: failed to create event for async uploads for device %s\n", func, + lm_ggml_backend_dev_name(dev)); + return nullptr; } + + events.emplace_back(event); } + + lm_ggml_backend_t backend = lm_ggml_backend_dev_init(dev, nullptr); + if (!backend) { + LLAMA_LOG_DEBUG("%s: failed to initialize backend for device %s for async uploads\n", func, + lm_ggml_backend_dev_name(dev)); + return nullptr; + } + + return backend; + }(__func__); + + if (upload_backend) { + LLAMA_LOG_DEBUG("%s: using async uploads for device %s, buffer type %s, backend %s\n", __func__, + lm_ggml_backend_dev_name(lm_ggml_backend_get_device(upload_backend)), + lm_ggml_backend_buft_name(lm_ggml_backend_buffer_get_type(bufs.at(0))), + lm_ggml_backend_name(upload_backend)); } -#endif for (struct lm_ggml_tensor * cur = lm_ggml_get_first_tensor(ctx); cur != NULL; cur = lm_ggml_get_next_tensor(ctx, cur)) { const auto * weight = get_weight(lm_ggml_get_name(cur)); @@ -4843,8 +4976,8 @@ struct llama_model_loader { if (use_mmap) { const auto & mapping = mappings.at(weight->idx); lm_ggml_backend_buffer_t buf_mmap = nullptr; - if (bufs_mmap.count(weight->idx)) { - buf_mmap = bufs_mmap.at(weight->idx); + if (bufs.count(weight->idx)) { + buf_mmap = bufs.at(weight->idx); } uint8_t * data = (uint8_t *) mapping->addr + weight->offs; @@ -4869,7 +5002,6 @@ struct llama_model_loader { lm_ggml_backend_tensor_set(cur, data, 0, n_size); } } else { - LM_GGML_ASSERT(weight->idx < files.size()); const auto & file = files.at(weight->idx); if (lm_ggml_backend_buffer_is_host(cur->buffer)) { file->seek(weight->offs, SEEK_SET); @@ -4880,9 +5012,8 @@ struct llama_model_loader { })); } } else { -#if defined(LM_GGML_USE_CUDA) - // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. - if (cuda_backend) { + // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. + if (upload_backend) { file->seek(weight->offs, SEEK_SET); size_t bytes_read = 0; @@ -4892,17 +5023,14 @@ struct llama_model_loader { lm_ggml_backend_event_synchronize(events[buffer_idx]); file->read_raw(host_ptrs[buffer_idx], read_iteration); - lm_ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); - lm_ggml_backend_event_record(events[buffer_idx]); + lm_ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + lm_ggml_backend_event_record(events[buffer_idx], upload_backend); bytes_read += read_iteration; ++buffer_idx; buffer_idx %= n_buffers; } - } - else -#endif - { + } else { read_buf.resize(n_size); file->seek(weight->offs, SEEK_SET); file->read_raw(read_buf.data(), n_size); @@ -4917,17 +5045,15 @@ struct llama_model_loader { size_done += n_size; } -#if defined(LM_GGML_USE_CUDA) - // free temporary resources used for async cuda uploads - if (cuda_backend) { - for (size_t idx = 0; idx < n_buffers;++idx) { - lm_ggml_backend_event_synchronize(events[idx]); - lm_ggml_backend_event_free(events[idx]); - lm_ggml_backend_buffer_free(host_buffers[idx]); - } - lm_ggml_backend_free(cuda_backend); + // free temporary resources used for async uploads + for (auto * event : events) { + lm_ggml_backend_event_synchronize(event); + lm_ggml_backend_event_free(event); } -#endif + for (auto * buf : host_buffers) { + lm_ggml_backend_buffer_free(buf); + } + lm_ggml_backend_free(upload_backend); // check validation results bool validation_failed = false; @@ -4966,6 +5092,57 @@ struct llama_model_loader { } }; +// temporary allocate memory for the input batch if needed +static const llama_seq_id batch_default_seq_id = 0; +struct llama_batch_allocr { + std::array seq_id_0 = {batch_default_seq_id}; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector logits; + struct llama_batch batch; + // optionally fulfill the batch returned by llama_batch_get_one + llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) { + batch = in_batch; + LM_GGML_ASSERT(batch.n_tokens > 0); + if (!batch.pos) { + // determine the last position in KV cache + llama_pos last_pos = -1; + for (const auto & cell : ctx.kv_self.cells) { + if (cell.has_seq_id(batch_default_seq_id)) { + last_pos = std::max(last_pos, cell.pos); + } + } + last_pos++; // next position + pos.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + pos[i] = i+last_pos; + } + batch.pos = pos.data(); + } + if (!batch.n_seq_id) { + n_seq_id.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + n_seq_id[i] = seq_id_0.size(); + } + batch.n_seq_id = n_seq_id.data(); + } + if (!batch.seq_id) { + seq_id.resize(batch.n_tokens + 1); + seq_id[batch.n_tokens] = NULL; + for (int32_t i = 0; i < batch.n_tokens; i++) { + seq_id[i] = seq_id_0.data(); + } + batch.seq_id = seq_id.data(); + } + if (!batch.logits) { + logits.resize(batch.n_tokens); + logits[logits.size() - 1] = true; + batch.logits = logits.data(); + } + } +}; + template<> bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) { uint32_t tmp; @@ -5015,6 +5192,8 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; @@ -5059,6 +5238,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_1B: return "1B"; case MODEL_1_3B: return "1.3B"; case MODEL_1_4B: return "1.4B"; + case MODEL_1_6B: return "1.6B"; case MODEL_2B: return "2B"; case MODEL_2_8B: return "2.8B"; case MODEL_3B: return "3B"; @@ -5087,6 +5267,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_MEDIUM: return "0.4B"; case MODEL_LARGE: return "0.8B"; case MODEL_XL: return "1.5B"; + case MODEL_A1_7B: return "A1.7B"; case MODEL_A2_7B: return "A2.7B"; case MODEL_8x7B: return "8x7B"; case MODEL_8x22B: return "8x22B"; @@ -5105,6 +5286,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ case LLAMA_VOCAB_TYPE_BPE: return "BPE"; case LLAMA_VOCAB_TYPE_WPM: return "WPM"; case LLAMA_VOCAB_TYPE_UGM: return "UGM"; + case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; default: return "unknown"; } } @@ -5120,7 +5302,7 @@ static void llm_load_hparams( llama_model_loader & ml, llama_model & model) { auto & hparams = model.hparams; - const lm_gguf_context * ctx = ml.meta; + const lm_gguf_context * ctx = ml.meta.get(); // get metadata as string for (int i = 0; i < lm_gguf_get_n_kv(ctx); i++) { @@ -5238,8 +5420,10 @@ static void llm_load_hparams( } } else { switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_1B; break; // Llama 3.2 1B case 22: model.type = e_model::MODEL_1B; break; case 26: model.type = e_model::MODEL_3B; break; + case 28: model.type = e_model::MODEL_3B; break; // Llama 3.2 3B // granite uses a vocab with len 49152 case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break; case 36: model.type = e_model::MODEL_8B; break; // granite @@ -5260,6 +5444,17 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_MINICPM3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + + switch (hparams.n_layer) { + case 62: model.type = e_model::MODEL_4B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_GROK: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5341,11 +5536,11 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { - case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small + case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base } } break; @@ -5625,6 +5820,14 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_OLMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_A1_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_OPENELM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5801,6 +6004,54 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_RWKV6: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 61: model.type = e_model::MODEL_14B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_3B; break; + case 40: model.type = e_model::MODEL_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_CHAMELEON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_34B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -5818,7 +6069,7 @@ static void llm_load_vocab( llama_model & model) { auto & vocab = model.vocab; - struct lm_gguf_context * ctx = ml.meta; + struct lm_gguf_context * ctx = ml.meta.get(); const auto kv = LLM_KV(model.arch); @@ -5834,33 +6085,40 @@ static void llm_load_vocab( vocab.type = LLAMA_VOCAB_TYPE_NONE; // default special tokens - vocab.special_bos_id = -1; - vocab.special_eos_id = -1; - vocab.special_unk_id = -1; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; - vocab.special_cls_id = -1; - vocab.special_mask_id = -1; - vocab.linefeed_id = -1; - + vocab.special_bos_id = LLAMA_TOKEN_NULL; + vocab.special_eos_id = LLAMA_TOKEN_NULL; + vocab.special_unk_id = LLAMA_TOKEN_NULL; + vocab.special_sep_id = LLAMA_TOKEN_NULL; + vocab.special_pad_id = LLAMA_TOKEN_NULL; + vocab.special_cls_id = LLAMA_TOKEN_NULL; + vocab.special_mask_id = LLAMA_TOKEN_NULL; + vocab.linefeed_id = LLAMA_TOKEN_NULL; + + // read vocab size from metadata + if (!ml.get_key(LLM_KV_VOCAB_SIZE, vocab.n_vocab, false)) { + vocab.n_vocab = 0; + LLAMA_LOG_WARN("%s: there is no vocab_size in metadata, vocab.n_vocab will be set to %u\n", __func__, vocab.n_vocab); + } return; - } else if (tokenizer_model == "llama") { + } + + if (tokenizer_model == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; // default special tokens vocab.special_bos_id = 1; vocab.special_eos_id = 2; vocab.special_unk_id = 0; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; - vocab.special_cls_id = -1; - vocab.special_mask_id = -1; + vocab.special_sep_id = LLAMA_TOKEN_NULL; + vocab.special_pad_id = LLAMA_TOKEN_NULL; + vocab.special_cls_id = LLAMA_TOKEN_NULL; + vocab.special_mask_id = LLAMA_TOKEN_NULL; } else if (tokenizer_model == "bert") { vocab.type = LLAMA_VOCAB_TYPE_WPM; // default special tokens - vocab.special_bos_id = -1; - vocab.special_eos_id = -1; + vocab.special_bos_id = LLAMA_TOKEN_NULL; + vocab.special_eos_id = LLAMA_TOKEN_NULL; vocab.special_unk_id = 100; vocab.special_sep_id = 102; vocab.special_pad_id = 0; @@ -5896,22 +6154,22 @@ static void llm_load_vocab( // default special tokens vocab.special_bos_id = 11; vocab.special_eos_id = 11; - vocab.special_unk_id = -1; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; - vocab.special_cls_id = -1; - vocab.special_mask_id = -1; + vocab.special_unk_id = LLAMA_TOKEN_NULL; + vocab.special_sep_id = LLAMA_TOKEN_NULL; + vocab.special_pad_id = LLAMA_TOKEN_NULL; + vocab.special_cls_id = LLAMA_TOKEN_NULL; + vocab.special_mask_id = LLAMA_TOKEN_NULL; } else if (tokenizer_model == "t5") { vocab.type = LLAMA_VOCAB_TYPE_UGM; // default special tokens - vocab.special_bos_id = -1; + vocab.special_bos_id = LLAMA_TOKEN_NULL; vocab.special_eos_id = 1; vocab.special_unk_id = 2; - vocab.special_sep_id = -1; + vocab.special_sep_id = LLAMA_TOKEN_NULL; vocab.special_pad_id = 0; - vocab.special_cls_id = -1; - vocab.special_mask_id = -1; + vocab.special_cls_id = LLAMA_TOKEN_NULL; + vocab.special_mask_id = LLAMA_TOKEN_NULL; const int precompiled_charsmap_keyidx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); if (precompiled_charsmap_keyidx != -1) { @@ -5930,6 +6188,15 @@ static void llm_load_vocab( } #endif } + } else if (tokenizer_model == "rwkv") { + vocab.type = LLAMA_VOCAB_TYPE_RWKV; + + // default special tokens + vocab.special_bos_id = LLAMA_TOKEN_NULL; + vocab.special_eos_id = LLAMA_TOKEN_NULL; + vocab.special_unk_id = LLAMA_TOKEN_NULL; + vocab.special_sep_id = LLAMA_TOKEN_NULL; + vocab.special_pad_id = LLAMA_TOKEN_NULL; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -5978,6 +6245,7 @@ static void llm_load_vocab( tokenizer_pre == "phi-2" || tokenizer_pre == "jina-es" || tokenizer_pre == "jina-de" || + tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "jina-v2-code") { @@ -6012,7 +6280,7 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "chatglm-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; - vocab.special_bos_id = -1; + vocab.special_bos_id = LLAMA_TOKEN_NULL; } else if ( tokenizer_pre == "viking") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING; @@ -6042,6 +6310,11 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "exaone") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE; + } else if ( + tokenizer_pre == "chameleon") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -6061,6 +6334,12 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.tokenizer_add_bos = false; vocab.tokenizer_add_eos = true; + } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_space_prefix = false; + vocab.tokenizer_clean_spaces = false; + vocab.tokenizer_add_bos = false; + vocab.tokenizer_add_eos = false; } else { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } @@ -6088,11 +6367,17 @@ static void llm_load_vocab( const uint32_t n_vocab = lm_gguf_get_arr_n(ctx, token_idx); + vocab.n_vocab = n_vocab; vocab.id_to_token.resize(n_vocab); for (uint32_t i = 0; i < n_vocab; i++) { std::string word = lm_gguf_get_arr_str(ctx, token_idx, i); - LM_GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + + //LM_GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + if (word.empty()) { + LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); + word = "[EMPTY_" + std::to_string(i) + "]"; + } vocab.token_to_id[word] = i; vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); @@ -6117,46 +6402,10 @@ static void llm_load_vocab( } LM_GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); + vocab.init_tokenizer(); + // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { - // For Fill-In-the-Middle (FIM)/infill models which where converted - // prior to support of FIM special tokens in GGUF, the following - // will allow those models to continue to work. The general names - // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and - // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once - // new versions of these models have been published. - std::string gen_name; - ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false); - - std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(), - [](unsigned char c){ return std::tolower(c); }); - - if (gen_name.find("code") != std::string::npos) { - if (model.arch == LLM_ARCH_LLAMA - && 32010 < vocab.id_to_token.size() - && vocab.id_to_token[32007].text.find("
") != std::string::npos
-              && vocab.id_to_token[32008].text.find("") != std::string::npos
-              && vocab.id_to_token[32009].text.find("") != std::string::npos
-              && vocab.id_to_token[32010].text.find("") != std::string::npos) {
-                vocab.special_prefix_id = 32007;
-                vocab.special_suffix_id = 32008;
-                vocab.special_middle_id = 32009;
-                vocab.special_eot_id    = 32010;
-            } else if (model.arch == LLM_ARCH_GEMMA
-              && 107 < vocab.id_to_token.size()
-              && vocab.id_to_token[67].text == "<|fim_prefix|>"
-              && vocab.id_to_token[69].text == "<|fim_suffix|>"
-              && vocab.id_to_token[68].text == "<|fim_middle|>"
-              && vocab.id_to_token[107].text == "") {
-                vocab.special_prefix_id = 67;
-                vocab.special_suffix_id = 69;
-                vocab.special_middle_id = 68;
-                // TODO: this is not EOT, it is "file separator" token, needs fix
-                //       https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572
-                //vocab.special_eot_id    = 70;
-                vocab.special_eot_id    = 107;
-            }
-        }
         try {
             vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n');
         } catch (const std::exception & e) {
@@ -6165,27 +6414,45 @@ static void llm_load_vocab(
         }
     } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
         vocab.linefeed_id = vocab.special_pad_id;
-    } else {
-        const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
+    } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
+        const std::vector ids = llama_tokenize_internal(vocab, "\n", false);
         LM_GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
         vocab.linefeed_id = ids[0];
+    } else {
+        const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
+
+        //LM_GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        if (ids.empty()) {
+            LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__);
+            vocab.linefeed_id = vocab.special_pad_id;
+        } else {
+            vocab.linefeed_id = ids[0];
+        }
     }
 
     // special tokens
     {
         const std::vector> special_token_types = {
-            { LLM_KV_TOKENIZER_BOS_ID,    vocab.special_bos_id    },
-            { LLM_KV_TOKENIZER_EOS_ID,    vocab.special_eos_id    },
-            { LLM_KV_TOKENIZER_UNK_ID,    vocab.special_unk_id    },
-            { LLM_KV_TOKENIZER_SEP_ID,    vocab.special_sep_id    },
-            { LLM_KV_TOKENIZER_PAD_ID,    vocab.special_pad_id    },
-            { LLM_KV_TOKENIZER_CLS_ID,    vocab.special_cls_id    },
-            { LLM_KV_TOKENIZER_MASK_ID,   vocab.special_mask_id   },
-            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
-            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
-            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
-            { LLM_KV_TOKENIZER_EOT_ID,    vocab.special_eot_id    },
-            { LLM_KV_TOKENIZER_EOM_ID,    vocab.special_eom_id    },
+            { LLM_KV_TOKENIZER_BOS_ID,     vocab.special_bos_id     },
+            { LLM_KV_TOKENIZER_EOS_ID,     vocab.special_eos_id     },
+            { LLM_KV_TOKENIZER_EOT_ID,     vocab.special_eot_id     },
+            { LLM_KV_TOKENIZER_EOM_ID,     vocab.special_eom_id     },
+            { LLM_KV_TOKENIZER_UNK_ID,     vocab.special_unk_id     },
+            { LLM_KV_TOKENIZER_SEP_ID,     vocab.special_sep_id     },
+            { LLM_KV_TOKENIZER_PAD_ID,     vocab.special_pad_id     },
+            { LLM_KV_TOKENIZER_CLS_ID,     vocab.special_cls_id     },
+            { LLM_KV_TOKENIZER_MASK_ID,    vocab.special_mask_id    },
+            { LLM_KV_TOKENIZER_FIM_PRE_ID, vocab.special_fim_pre_id },
+            { LLM_KV_TOKENIZER_FIM_SUF_ID, vocab.special_fim_suf_id },
+            { LLM_KV_TOKENIZER_FIM_MID_ID, vocab.special_fim_mid_id },
+            { LLM_KV_TOKENIZER_FIM_PAD_ID, vocab.special_fim_pad_id },
+            { LLM_KV_TOKENIZER_FIM_REP_ID, vocab.special_fim_rep_id },
+            { LLM_KV_TOKENIZER_FIM_SEP_ID, vocab.special_fim_sep_id },
+
+            // deprecated
+            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_fim_pre_id },
+            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_fim_suf_id },
+            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_fim_mid_id },
         };
 
         for (const auto & it : special_token_types) {
@@ -6216,68 +6483,230 @@ static void llm_load_vocab(
             }
         }
 
-        // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc.
-        //
-        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
-        //       for now, we apply this workaround to find the EOT token based on its text
-        if (vocab.special_eot_id == -1) {
-            for (const auto & t : vocab.token_to_id) {
-                if (
-                        // TODO: gemma "" is exported as a normal token, so the following check does not work
-                        //       need to fix convert script
-                        //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
-                        (t.first == "<|eot_id|>" ||
-                         t.first == "<|im_end|>" ||
-                         t.first == "<|end|>" ||
-                         t.first == "" ||
-                         t.first == "<|endoftext|>"
-                        )
+        // auto-detect special tokens by text
+        // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_...
+        //       for now, we apply this workaround to find the tokens based on their text
+
+        for (const auto & t : vocab.token_to_id) {
+            // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc.
+            if (vocab.special_eot_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eot_id|>"
+                        || t.first == "<|im_end|>"
+                        || t.first == "<|end|>"
+                        || t.first == ""
+                        || t.first == "<|endoftext|>"
+                        || t.first == ""
+                        || t.first == "<|end▁of▁sentence|>" // DeepSeek
                    ) {
                     vocab.special_eot_id = t.second;
-                    break;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
                 }
             }
-        }
 
-        // find EOM token: "<|eom_id|>"
-        //
-        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
-        //       for now, we apply this workaround to find the EOM token based on its text
-        if (vocab.special_eom_id == -1) {
-            const auto & t = vocab.token_to_id.find("<|eom_id|>");
-            if (t != vocab.token_to_id.end()) {
-                vocab.special_eom_id = t->second;
+            // find EOM token: "<|eom_id|>"
+            if (vocab.special_eom_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eom_id|>"
+                        ) {
+                    vocab.special_eom_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
             }
-        }
-    }
 
-    // build special tokens cache
-    {
-        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
-                vocab.cache_special_tokens.push_back(id);
+            // find FIM_PRE token: "<|fim_prefix|>", "", "
", etc.
+            if (vocab.special_fim_pre_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "
"
+                        ) {
+                    vocab.special_fim_pre_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
             }
-        }
 
-        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
-            [&] (const llama_vocab::id a, const llama_vocab::id b) {
-                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
+            // find FIM_SUF token: "<|fim_suffix|>", "", "", etc.
+            if (vocab.special_fim_suf_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == ""
+                        ) {
+                    vocab.special_fim_suf_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
             }
-        );
-
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
-    }
 
-    // build token to piece cache
-    {
-        size_t size_cache = 0;
+            // find FIM_MID token: "<|fim_middle|>", "", "", etc.
+            if (vocab.special_fim_mid_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == ""
+                        ) {
+                    vocab.special_fim_mid_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
 
-        std::vector cache_token_to_piece(n_vocab);
+            // find FIM_PAD token: "<|fim_pad|>", "", "", etc.
+            if (vocab.special_fim_pad_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == ""
+                        || t.first == ""
+                        ) {
+                    vocab.special_fim_pad_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
 
-        for (uint32_t id = 0; id < n_vocab; ++id) {
-            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
+            // find FIM_REP token: "<|fim_repo|>", "", "", etc.
+            if (vocab.special_fim_rep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == ""
+                        || t.first == ""
+                        ) {
+                    vocab.special_fim_rep_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
 
-            size_cache += cache_token_to_piece[id].size();
+            // find FIM_SEP token: "<|file_sep|>"
+            if (vocab.special_fim_sep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    vocab.special_fim_sep_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+        }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        vocab.special_eog_ids.clear();
+
+        if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
+        }
+
+        if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
+        }
+
+        if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
+        }
+
+        for (const auto & t : vocab.token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == ""
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == ""
+               ) {
+                vocab.special_eog_ids.insert(t.second);
+                if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
+                    vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            } else {
+                // token is control, but not marked as EOG -> print a debug log
+                if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
+            }
+        }
+
+        // sanity checks
+        if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eot_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eom_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+    }
+
+    // build special tokens cache
+    {
+        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
+            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
+                vocab.cache_special_tokens.push_back(id);
+            }
+        }
+
+        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
+            [&] (const llama_vocab::id a, const llama_vocab::id b) {
+                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
+            }
+        );
+
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
+    }
+
+    // build token to piece cache
+    {
+        size_t size_cache = 0;
+
+        std::vector cache_token_to_piece(n_vocab);
+
+        for (uint32_t id = 0; id < n_vocab; ++id) {
+            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
+
+            size_cache += cache_token_to_piece[id].size();
         }
 
         std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
@@ -6439,19 +6868,28 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
 
     // special tokens
-    if (vocab.special_bos_id    != -1) { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,  vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
-    if (vocab.special_eos_id    != -1) { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,  vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
-    if (vocab.special_unk_id    != -1) { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,  vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
-    if (vocab.special_sep_id    != -1) { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,  vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
-    if (vocab.special_pad_id    != -1) { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,  vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
-    if (vocab.special_cls_id    != -1) { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,  vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
-    if (vocab.special_mask_id   != -1) { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
-
-    if (vocab.linefeed_id       != -1) { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,       vocab.id_to_token[vocab.linefeed_id].text.c_str() );       }
-    if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token        = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); }
-    if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token        = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
-    if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
-    if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
+    if (vocab.special_bos_id  != -1)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,     vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
+    if (vocab.special_eos_id  != -1)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,     vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
+    if (vocab.special_eot_id  != -1)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,     vocab.id_to_token[vocab.special_eot_id].text.c_str() );  }
+    if (vocab.special_eom_id  != -1)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, vocab.special_eom_id,     vocab.id_to_token[vocab.special_eom_id].text.c_str() );  }
+    if (vocab.special_unk_id  != -1)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,     vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
+    if (vocab.special_sep_id  != -1)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,     vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
+    if (vocab.special_pad_id  != -1)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,     vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
+    if (vocab.special_cls_id  != -1)    { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,     vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
+    if (vocab.special_mask_id != -1)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id,    vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
+
+    if (vocab.linefeed_id != -1)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,        vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
+
+    if (vocab.special_fim_pre_id != -1) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token[vocab.special_fim_pre_id].text.c_str() ); }
+    if (vocab.special_fim_suf_id != -1) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token[vocab.special_fim_suf_id].text.c_str() ); }
+    if (vocab.special_fim_mid_id != -1) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token[vocab.special_fim_mid_id].text.c_str() ); }
+    if (vocab.special_fim_pad_id != -1) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token[vocab.special_fim_pad_id].text.c_str() ); }
+    if (vocab.special_fim_rep_id != -1) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token[vocab.special_fim_rep_id].text.c_str() ); }
+    if (vocab.special_fim_sep_id != -1) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token[vocab.special_fim_sep_id].text.c_str() ); }
+
+    for (const auto & id : vocab.special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
+    }
 
     LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
 
@@ -6469,6 +6907,363 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
     }
+
+    if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
+        LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
+        LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
+        LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
+    }
+}
+
+enum llm_tensor_layer {
+    LLM_TENSOR_LAYER_INPUT,
+    LLM_TENSOR_LAYER_REPEATING,
+    LLM_TENSOR_LAYER_OUTPUT,
+};
+
+struct llm_tensor_info {
+    llm_tensor_layer layer;
+    lm_ggml_op op;
+};
+
+static const std::map llm_tensor_info_mapping = {
+    {LLM_TENSOR_TOKEN_EMBD,                 {LLM_TENSOR_LAYER_INPUT, LM_GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_POS_EMBD,                   {LLM_TENSOR_LAYER_INPUT, LM_GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_TOKEN_EMBD_NORM,            {LLM_TENSOR_LAYER_INPUT, LM_GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_TOKEN_TYPES,                {LLM_TENSOR_LAYER_INPUT, LM_GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_OUTPUT,                     {LLM_TENSOR_LAYER_OUTPUT, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CLS,                        {LLM_TENSOR_LAYER_OUTPUT, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CLS_OUT,                    {LLM_TENSOR_LAYER_OUTPUT, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_OUTPUT_NORM,                {LLM_TENSOR_LAYER_OUTPUT, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_DEC_OUTPUT_NORM,            {LLM_TENSOR_LAYER_OUTPUT, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ENC_OUTPUT_NORM,            {LLM_TENSOR_LAYER_OUTPUT, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ROPE_FREQS,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ROPE}},
+    {LLM_TENSOR_ROPE_FACTORS_LONG,          {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ROPE}},
+    {LLM_TENSOR_ROPE_FACTORS_SHORT,         {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ROPE}},
+    {LLM_TENSOR_ATTN_Q,                     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K,                     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V,                     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_QKV,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_OUT,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_GATE,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_DOWN,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_UP,                     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_DOWN_SHEXP,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_GATE_SHEXP,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_UP_SHEXP,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_Q_A,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_Q_B,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_KV_A_MQA,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_KV_B,                  {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_ATTN_Q,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_ATTN_K,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_Q,                     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K,                     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V,                     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_QKV,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_OUT,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_GATE,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_DOWN,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_UP,                     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_DOWN_SHEXP,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_GATE_SHEXP,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_UP_SHEXP,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_Q_A,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_Q_B,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_KV_A_MQA,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_KV_B,                  {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_ATTN_Q,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_ATTN_K,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_ATTN_V,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_ATTN_OUT,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_CROSS_ATTN_Q,           {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_CROSS_ATTN_K,           {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_CROSS_ATTN_V,           {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_CROSS_ATTN_OUT,         {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_FFN_GATE,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_FFN_DOWN,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_FFN_UP,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ENC_ATTN_Q,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ENC_ATTN_K,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ENC_ATTN_V,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ENC_ATTN_OUT,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ENC_FFN_GATE,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ENC_FFN_DOWN,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ENC_FFN_UP,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_GATE_INP_SHEXP,         {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_GATE_INP,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_SSM_IN,                     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_SSM_X,                      {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_SSM_DT,                     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_SSM_OUT,                    {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_W1,                {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_W2,                {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_DECAY_W1,          {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_DECAY_W2,          {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_KEY,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_VALUE,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_RECEPTANCE,        {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_GATE,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_OUTPUT,            {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CHANNEL_MIX_KEY,            {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,     {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CHANNEL_MIX_VALUE,          {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_ACT,                    {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_DIV}},
+    {LLM_TENSOR_SSM_CONV1D,                 {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_SSM_CONV}},
+    {LLM_TENSOR_SSM_A,                      {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_SSM_SCAN}},
+    {LLM_TENSOR_SSM_D,                      {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_LERP_X,            {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_LN,                {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_CHANNEL_MIX_LERP_K,         {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_CHANNEL_MIX_LERP_R,         {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_LERP_W,            {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_LERP_K,            {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_LERP_V,            {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_LERP_R,            {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_LERP_G,            {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_DECAY,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_FIRST,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_RWKV_WKV}},
+    {LLM_TENSOR_ATTN_NORM,                  {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ATTN_NORM_2,                {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ATTN_OUT_NORM,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ATTN_POST_NORM,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_FFN_NORM,                   {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_FFN_POST_NORM,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_FFN_NORM_EXPS,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ATTN_Q_NORM,                {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ATTN_K_NORM,                {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_LAYER_OUT_NORM,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ATTN_Q_A_NORM,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ATTN_KV_A_NORM,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ATTN_SUB_NORM,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_FFN_SUB_NORM,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_DEC_ATTN_NORM,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_DEC_CROSS_ATTN_NORM,        {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_DEC_FFN_NORM,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ENC_ATTN_NORM,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_ENC_FFN_NORM,               {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}},
+    {LLM_TENSOR_DEC_ATTN_REL_B,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_ENC_ATTN_REL_B,             {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_FFN_DOWN_EXPS,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_GATE_EXPS,              {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_UP_EXPS,                {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT_ID}},
+    // this tensor is loaded for T5, but never used
+    {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_NONE}},
+};
+
+// checks if the weight tensor can be used with the specified buffer type and device
+static bool weight_buft_supported(const llama_hparams & hparams, lm_ggml_tensor * w, lm_ggml_op op, lm_ggml_backend_buffer_type_t buft, lm_ggml_backend_dev_t dev) {
+    LM_GGML_ASSERT(w != nullptr);
+
+    if (op == LM_GGML_OP_NONE) {
+        return true;
+    }
+
+    lm_ggml_init_params params = {
+        /*.mem_size   =*/ lm_ggml_tensor_overhead()*8,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    lm_ggml_context_ptr ctx_ptr { lm_ggml_init(params) };
+    if (!ctx_ptr) {
+        throw std::runtime_error(format("failed to create ggml context"));
+    }
+    lm_ggml_context * ctx = ctx_ptr.get();
+
+    lm_ggml_tensor * op_tensor = nullptr;
+
+    switch (op) {
+        case LM_GGML_OP_GET_ROWS:
+            {
+                lm_ggml_tensor * b = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, 512);
+                op_tensor = lm_ggml_get_rows(ctx, w, b);
+            } break;
+        case LM_GGML_OP_MUL_MAT:
+            {
+                lm_ggml_tensor * b = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
+                op_tensor = lm_ggml_mul_mat(ctx, w, b);
+            } break;
+        case LM_GGML_OP_MUL_MAT_ID:
+            {
+                int n_expert_used = hparams.n_expert_used;
+                lm_ggml_tensor * b = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, w->ne[0], n_expert_used, 512);
+                lm_ggml_tensor * ids = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_I32, n_expert_used, 512);
+                op_tensor = lm_ggml_mul_mat_id(ctx, w, b, ids);
+            } break;
+        case LM_GGML_OP_ADD:
+            {
+                lm_ggml_tensor * a = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, w->ne[0], 512);
+                op_tensor = lm_ggml_add(ctx, a, w);
+            } break;
+        case LM_GGML_OP_MUL:
+            {
+                lm_ggml_tensor * a = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, w->ne[0], 512);
+                op_tensor = lm_ggml_mul(ctx, a, w);
+            } break;
+        case LM_GGML_OP_DIV:
+            {
+                lm_ggml_tensor * a = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, w->ne[0]);
+                op_tensor = lm_ggml_div(ctx, a, w);
+            } break;
+        case LM_GGML_OP_ROPE:
+            {
+                int n_embd_head = hparams.n_embd_head_v;
+                int n_head = hparams.n_head();
+                lm_ggml_tensor * a = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, n_embd_head, n_head, 512);
+                lm_ggml_tensor * b = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, 512);
+                op_tensor = lm_ggml_rope_ext(
+                    ctx, a, b, w,
+                    0, 0, 0, 0, 0,
+                    0, 0, 0, 0
+                );
+
+            } break;
+        case LM_GGML_OP_SSM_CONV:
+            {
+                // FIXME
+                lm_ggml_tensor * conv_x = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, 12345, w->ne[1], 6789);
+                op_tensor = lm_ggml_ssm_conv(ctx, conv_x, w);
+            } break;
+        case LM_GGML_OP_SSM_SCAN:
+            {
+                // FIXME
+                const int64_t d_state      = w->ne[0];
+                const int64_t d_inner      = w->ne[1];
+                const int64_t n_seq_tokens = 512;
+                const int64_t n_seqs       = 1;
+                lm_ggml_tensor * s  = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_state, d_inner, n_seqs);
+                lm_ggml_tensor * x = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
+                lm_ggml_tensor * dt = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
+                lm_ggml_tensor * B = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
+                lm_ggml_tensor * C = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
+                op_tensor = lm_ggml_ssm_scan(ctx, s, x, dt, w, B, C);
+            } break;
+        case LM_GGML_OP_RWKV_WKV:
+            {
+                // FIXME
+                const int64_t S = 123;
+                const int64_t H = 123;
+                const int64_t n_tokens = 123;
+                const int64_t n_seqs = 123;
+                lm_ggml_tensor  * k = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, S, 1, H, n_tokens);
+                lm_ggml_tensor  * v = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, 1, S, H, n_tokens);
+                lm_ggml_tensor  * r = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, 1, S, H, n_tokens);
+                lm_ggml_tensor  * tf = w;
+                lm_ggml_tensor  * td = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, 1, S, H, n_tokens);
+                lm_ggml_tensor  * state = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, S, n_seqs, S, H);
+                op_tensor = lm_ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
+            } break;
+        default:
+            LM_GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, lm_ggml_op_name(op), w->name);
+    }
+
+    // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
+    LM_GGML_ASSERT(w->buffer == nullptr);
+    w->buffer = lm_ggml_backend_buft_alloc_buffer(buft, 0);
+    bool op_supported = lm_ggml_backend_dev_supports_op(dev, op_tensor);
+    lm_ggml_backend_buffer_free(w->buffer);
+    w->buffer = nullptr;
+
+    return op_supported;
+}
+
+// find the first buffer type in the list that can use the tensor
+static lm_ggml_backend_buffer_type_t select_weight_buft(const llama_model & model, lm_ggml_tensor * tensor, lm_ggml_op op, const llama_model::buft_list_t & buft_list) {
+    LM_GGML_ASSERT(!buft_list.empty());
+    for (const auto & cur : buft_list) {
+        lm_ggml_backend_dev_t cur_dev = cur.first;
+        lm_ggml_backend_buffer_type_t cur_buft = cur.second;
+        if (weight_buft_supported(model.hparams, tensor, op, cur_buft, cur_dev)) {
+            return cur_buft;
+        }
+    }
+    return nullptr;
+}
+
+// CPU: ACCEL -> CPU extra -> GPU host -> CPU
+static llama_model::buft_list_t make_cpu_buft_list(llama_model & model) {
+    llama_model::buft_list_t buft_list;
+
+    // add ACCEL buffer types
+    for (size_t i = 0; i < lm_ggml_backend_dev_count(); ++i) {
+        lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i);
+        if (lm_ggml_backend_dev_type(dev) == LM_GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+            auto * buft = lm_ggml_backend_dev_buffer_type(dev);
+            // skip
+            if (buft != lm_ggml_backend_cpu_buffer_type()) {
+                buft_list.emplace_back(dev, buft);
+            }
+        }
+    }
+
+    // add extra buffer types
+    auto * cpu_dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU);
+    auto * cpu_reg = lm_ggml_backend_dev_backend_reg(cpu_dev);
+    auto lm_ggml_backend_dev_get_extra_bufts_fn = (lm_ggml_backend_dev_get_extra_bufts_t)
+        lm_ggml_backend_reg_get_proc_address(cpu_reg, "lm_ggml_backend_cpu_get_extra_bufts");
+    if (lm_ggml_backend_dev_get_extra_bufts_fn) {
+        lm_ggml_backend_buffer_type_t * extra_bufts = lm_ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
+        while (extra_bufts && *extra_bufts) {
+            buft_list.emplace_back(cpu_dev, *extra_bufts);
+            ++extra_bufts;
+        }
+    }
+
+    // add a host buffer type
+    // storing the tensors in a host buffer is useful when the processing of large batches
+    // is offloaded to a GPU device, since it reduces the time spent on data transfers
+    // generally, this will be done using the first device in the list
+    // a better approach would be to handle this on a weight-by-weight basis using the offload_op
+    // function of the device to determine if it would benefit from being stored in a host buffer
+    for (auto * dev : model.devices) {
+        lm_ggml_backend_buffer_type_t buft = lm_ggml_backend_dev_host_buffer_type(dev);
+        if (buft) {
+            buft_list.emplace_back(dev, buft);
+            break;
+        }
+    }
+
+    // add the CPU buffer type
+    for (size_t i = 0; i < lm_ggml_backend_dev_count(); ++i) {
+        lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i);
+        if (lm_ggml_backend_dev_type(dev) == LM_GGML_BACKEND_DEVICE_TYPE_CPU) {
+            buft_list.emplace_back(dev, lm_ggml_backend_dev_buffer_type(dev));
+        }
+    }
+
+    return buft_list;
+}
+
+// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU
+static llama_model::buft_list_t make_gpu_buft_list(lm_ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) {
+    llama_model::buft_list_t buft_list;
+
+    // add the device split buffer type if requested and available
+    if (split_mode == LLAMA_SPLIT_MODE_ROW) {
+        lm_ggml_backend_reg_t reg = lm_ggml_backend_dev_backend_reg(dev);
+        auto lm_ggml_backend_split_buffer_type_fn = (lm_ggml_backend_split_buffer_type_t)
+            lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_split_buffer_type");
+        if (lm_ggml_backend_split_buffer_type_fn) {
+            size_t dev_index = [&]() {
+                auto * reg = lm_ggml_backend_dev_backend_reg(dev);
+                for (size_t i = 0; i < lm_ggml_backend_reg_dev_count(reg); ++i) {
+                    if (lm_ggml_backend_reg_dev_get(reg, i) == dev) {
+                        return i;
+                    }
+                }
+                throw std::runtime_error(format("device %s not found in its backend reg", lm_ggml_backend_dev_name(dev)));
+            }();
+            auto * buft = lm_ggml_backend_split_buffer_type_fn(dev_index, tensor_split);
+            if (buft != nullptr) {
+                buft_list.emplace_back(dev, buft);
+            }
+        }
+    }
+
+    // add the device default buffer type
+    buft_list.emplace_back(dev, lm_ggml_backend_dev_buffer_type(dev));
+
+    return buft_list;
 }
 
 // Returns false if cancelled by progress_callback
@@ -6482,8 +7277,6 @@ static bool llm_load_tensors(
         bool use_mlock,
         llama_progress_callback progress_callback,
         void * progress_callback_user_data) {
-    model.t_start_us = lm_ggml_time_us();
-
     auto & hparams = model.hparams;
 
     model.split_mode   = split_mode;
@@ -6491,116 +7284,93 @@ static bool llm_load_tensors(
     model.n_gpu_layers = n_gpu_layers;
 
     const int n_layer     = hparams.n_layer;
-    const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
     bool use_mmap_buffer = true;
 
-    // there is very little benefit to offloading the input layer, so always keep it on the CPU
-    model.buft_input = llama_default_buffer_type_cpu(true);
-    //model.buft_input = llama_default_buffer_type_offload(main_gpu);
-
-    model.buft_layer.resize(n_layer);
-
-    // assign cpu layers
-    for (int i = 0; i < i_gpu_start; ++i) {
-        model.buft_layer[i] = llama_default_buffer_type_cpu(true);
-    }
-
-    if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
-        // calculate the split points
-        int device_count = llama_get_device_count(model);
-        bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
-        std::vector splits(device_count);
-        if (all_zero) {
-            // default split, by free memory
-            for (int i = 0; i < device_count; ++i) {
-                splits[i] = llama_get_device_memory(model, i);
-            }
-        } else {
-            std::copy(tensor_split, tensor_split + device_count, splits.begin());
-        }
-
-        // sum and normalize the splits to get the split points
-        float split_sum = 0.0f;
-        for (int i = 0; i < device_count; ++i) {
-            split_sum += splits[i];
-            splits[i] = split_sum;
-        }
+    // build a list of buffer types for the CPU and GPU devices
+    model.cpu_buft_list = make_cpu_buft_list(model);
+    for (auto * dev : model.devices) {
+        llama_model::buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split);
+        // add CPU buffer types as a fallback
+        buft_list.insert(buft_list.end(), model.cpu_buft_list.begin(), model.cpu_buft_list.end());
+        model.gpu_buft_list.emplace(dev, std::move(buft_list));
+    }
+
+    // calculate the split points
+    int device_count = llama_get_device_count(model);
+    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
+    std::vector splits(device_count);
+    if (all_zero) {
+        // default split, by free memory
         for (int i = 0; i < device_count; ++i) {
-            splits[i] /= split_sum;
-        }
-
-        // assign the repeating layers to the devices according to the splits
-        int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
-        for (int i = i_gpu_start; i < n_layer; ++i) {
-            int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
-            model.buft_layer[i] = llama_default_buffer_type_offload(model, layer_gpu);
-        }
-        // assign the output layer
-        if (n_gpu_layers > n_layer) {
-            int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin();
-            model.buft_output = llama_default_buffer_type_offload(model, layer_gpu);
-        } else {
-            model.buft_output = llama_default_buffer_type_cpu(true);
+            lm_ggml_backend_dev_t dev = model.devices[i];
+            size_t total;
+            size_t free;
+            lm_ggml_backend_dev_memory(dev, &free, &total);
+            splits[i] = free;
         }
     } else {
-        lm_ggml_backend_buffer_type_t split_buft;
-        if (split_mode == LLAMA_SPLIT_MODE_ROW) {
-            split_buft = llama_default_buffer_type_split(model, main_gpu, tensor_split);
-        } else {
-            // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
-            split_buft = llama_default_buffer_type_offload(model, main_gpu);
-        }
-        // assign the repeating layers
-        for (int i = i_gpu_start; i < n_layer; ++i) {
-            model.buft_layer[i] = {
-                split_buft,
-                llama_default_buffer_type_offload(model, main_gpu)
-            };
-        }
-        // assign the output layer
-        if (n_gpu_layers > n_layer) {
-            model.buft_output = {
-                split_buft,
-                llama_default_buffer_type_offload(model, main_gpu)
-            };
-        } else {
-            model.buft_output = llama_default_buffer_type_cpu(true);
-        }
+        std::copy(tensor_split, tensor_split + device_count, splits.begin());
     }
 
-    // count used buffer types
-    std::map buft_layer_count;
-    buft_layer_count[model.buft_input.buft]++;
-    buft_layer_count[model.buft_input.buft_matrix]++;
-    buft_layer_count[model.buft_output.buft]++;
-    buft_layer_count[model.buft_output.buft_matrix]++;
-    for (int i = 0; i < n_layer; ++i) {
-        buft_layer_count[model.buft_layer[i].buft]++;
-        buft_layer_count[model.buft_layer[i].buft_matrix]++;
+    // sum and normalize the splits to get the split points
+    float split_sum = 0.0f;
+    for (int i = 0; i < device_count; ++i) {
+        split_sum += splits[i];
+        splits[i] = split_sum;
+    }
+    for (int i = 0; i < device_count; ++i) {
+        splits[i] /= split_sum;
     }
 
-    // create one context per buffer type
-    size_t ctx_size = lm_ggml_tensor_overhead()*(ml.n_tensors + 1); // +1 for models where tok_embd is duplicated as output
+    lm_ggml_backend_dev_t cpu_dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU);
+    const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
+    const int act_gpu_layers = model.devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
+    auto get_layer_buft_list = [&](int il) -> llama_model::layer_dev {
+        if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
+            return {cpu_dev, &model.cpu_buft_list};
+        }
+        int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(il - i_gpu_start)/act_gpu_layers) - splits.begin();
+        auto * dev = model.devices.at(layer_gpu);
+        return {dev, &model.gpu_buft_list.at(dev)};
+    };
+
+    // assign the input layer
+    // there is very little benefit to offloading the input layer, so always keep it on the CPU
+    model.dev_input = { cpu_dev, &model.cpu_buft_list };
 
-    // for moe merged tensors
-    ctx_size += lm_ggml_tensor_overhead()*n_layer*3;
+    // assign the repeating layers to the devices according to the splits
+    model.dev_layer.resize(n_layer);
+    for (int il = 0; il < n_layer; ++il) {
+        model.dev_layer[il] = get_layer_buft_list(il);
+    }
+    // assign the output layer
+    model.dev_output = get_layer_buft_list(n_layer);
+
+    // one ggml context per buffer type
+    int max_n_tensors = ml.n_tensors;
+    max_n_tensors += 1;         // duplicated output tensor
+    max_n_tensors += n_layer*2; // duplicated rope freq tensors
+    const size_t ctx_size = lm_ggml_tensor_overhead()*max_n_tensors;
 
     std::map ctx_map;
-    for (auto & it : buft_layer_count) {
-        struct lm_ggml_init_params params = {
-            /*.mem_size   =*/ ctx_size,
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc   =*/ true,
-        };
-        lm_ggml_context * ctx = lm_ggml_init(params);
-        if (!ctx) {
-            throw std::runtime_error(format("failed to create context"));
+    auto ctx_for_buft = [&](lm_ggml_backend_buffer_type_t buft) -> lm_ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            lm_ggml_init_params params = {
+                /*.mem_size   =*/ ctx_size,
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+            lm_ggml_context * ctx = lm_ggml_init(params);
+            if (!ctx) {
+                throw std::runtime_error(format("failed to create ggml context"));
+            }
+            ctx_map[buft] = ctx;
+            model.ctxs.emplace_back(ctx);
+            return ctx;
         }
-        ctx_map[it.first] = ctx;
-        model.ctxs.push_back(ctx);
-    }
-
-    LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MiB\n", __func__, model.ctxs.size()*ctx_size/1024.0/1024.0);
+        return it->second;
+    };
 
     // create tensors for the weights
     {
@@ -6616,6 +7386,7 @@ static bool llm_load_tensors(
         const int64_t n_embd_gqa    = n_embd_v_gqa;
         const int64_t n_vocab       = hparams.n_vocab;
         const int64_t n_vocab_type  = hparams.n_vocab_type;
+        const int64_t n_rot         = hparams.n_rot;
         const int64_t n_expert      = hparams.n_expert;
         const int64_t n_expert_used = hparams.n_expert_used;
         const int64_t n_ctx_train   = hparams.n_ctx_train;
@@ -6624,996 +7395,988 @@ static bool llm_load_tensors(
             throw std::runtime_error("model has expert layers but no expert layers are used");
         }
 
-        lm_ggml_context * ctx_input        = ctx_map.at(model.buft_input.buft);
-        lm_ggml_context * ctx_output       = ctx_map.at(model.buft_output.buft);
-        lm_ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix);
+        int n_moved_tensors = 0;
+        lm_ggml_tensor * first_moved_tensor = nullptr;
+        lm_ggml_backend_buffer_type_t first_moved_from_buft = nullptr;
+        lm_ggml_backend_buffer_type_t first_moved_to_buft = nullptr;
+
+        auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> lm_ggml_tensor * {
+            lm_ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str());
+
+            if (!t_meta) {
+                if (flags & llama_model_loader::TENSOR_NOT_REQUIRED) {
+                    return nullptr;
+                }
+                throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
+            }
+
+            // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops
+            // the tensor is duplicated
+            // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor
+            llm_tensor tn_tensor = tn.tensor;
+            if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & llama_model_loader::TENSOR_DUPLICATED) {
+                tn_tensor = LLM_TENSOR_OUTPUT;
+            }
+
+            auto it = llm_tensor_info_mapping.find(tn_tensor);
+            if (it == llm_tensor_info_mapping.end()) {
+                throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str()));
+            }
+            const auto & info = it->second;
+
+            // tensors with "bias" suffix are always used with LM_GGML_OP_ADD
+            lm_ggml_op op;
+            bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
+            if (bias) {
+                op = LM_GGML_OP_ADD;
+            } else {
+                op = info.op;
+            }
+
+            // sanity checks
+            if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) {
+                if (tn.bid != -1) {
+                    LM_GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());
+                }
+            } else {
+                if (tn.bid == -1) {
+                    LM_GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str());
+                }
+            }
+
+            // select the buffer type for this tensor
+            llama_model::buft_list_t * buft_list;
+            switch (info.layer) {
+                case LLM_TENSOR_LAYER_INPUT:
+                    buft_list = model.dev_input.buft_list;
+                    break;
+                case LLM_TENSOR_LAYER_OUTPUT:
+                    buft_list = model.dev_output.buft_list;
+                    break;
+                case LLM_TENSOR_LAYER_REPEATING:
+                    buft_list = model.dev_layer.at(tn.bid).buft_list;
+                    break;
+                default:
+                    LM_GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
+            }
+
+            lm_ggml_backend_buffer_type_t buft = select_weight_buft(model, t_meta, op, *buft_list);
+            if (!buft) {
+                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+            }
+
+            // avoid using a host buffer when using mmap
+            auto * buft_dev = lm_ggml_backend_buft_get_device(buft);
+            if (ml.use_mmap && buft == lm_ggml_backend_dev_host_buffer_type(buft_dev)) {
+                auto * cpu_dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU);
+                buft = lm_ggml_backend_dev_buffer_type(cpu_dev);
+            }
+
+            if (buft != buft_list->front().second) {
+                n_moved_tensors++;
+                if (!first_moved_tensor) {
+                    first_moved_tensor = t_meta;
+                    first_moved_from_buft = buft_list->front().second;
+                    first_moved_to_buft   = buft;
+                }
+            }
+
+            lm_ggml_context * ctx = ctx_for_buft(buft);
 
-        auto ctx_for_layer       = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); };
-        auto ctx_for_layer_split = [&](int i) { return ctx_map.at(model.buft_layer[i].buft_matrix); };
+            // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one
+            if (flags & llama_model_loader::TENSOR_DUPLICATED) {
+                lm_ggml_tensor * t = lm_ggml_get_tensor(ctx, tn.str().c_str());
+                if (t) {
+                    return t;
+                }
+            }
+            return ml.create_tensor(ctx, tn, ne, flags);
+        };
 
         model.layers.resize(n_layer);
 
+        // TODO: move to a separate function
         const auto tn = LLM_TN(model.arch);
         switch (model.arch) {
             case LLM_ARCH_LLAMA:
             case LLM_ARCH_REFACT:
             case LLM_ARCH_MINICPM:
+            case LLM_ARCH_GRANITE:
+            case LLM_ARCH_GRANITE_MOE:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
 
                         // optional bias tensors
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
 
                         if (n_expert == 0) {
-                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                            layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                            layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
 
                             // optional MLP bias
-                            layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         } else {
-                            layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
-
-                            layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            if (layer.ffn_gate_exps) {
-                                layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
-                                layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
-                            } else {
-                                // merge split expert into a single tensor for compatibility with older models
-                                // requires disabling mmap
-                                use_mmap_buffer = false;
-
-                                lm_ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type;
-                                lm_ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type;
-                                lm_ggml_type type_up   = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, 0).c_str())->type;
-
-                                layer.ffn_gate_exps = lm_ggml_new_tensor_3d(ctx_split, type_gate, n_embd,   n_ff, n_expert);
-                                layer.ffn_down_exps = lm_ggml_new_tensor_3d(ctx_split, type_down,   n_ff, n_embd, n_expert);
-                                layer.ffn_up_exps   = lm_ggml_new_tensor_3d(ctx_split, type_up,   n_embd,   n_ff, n_expert);
-
-                                lm_ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str());
-                                lm_ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str());
-                                lm_ggml_set_name(layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i).c_str());
-
-                                for (uint32_t x = 0; x < n_expert; ++x) {
-                                    // the individual experts are loaded into a view of the merged tensor
-                                    ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x);
-                                    ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x);
-                                    ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x);
-                                }
-                            }
+                            layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
                         }
                     }
                 } break;
+            case LLM_ARCH_MINICPM3:
+                {
+                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
+                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
+                    const int64_t q_lora_rank  = hparams.n_lora_q;
+                    const int64_t kv_lora_rank = hparams.n_lora_kv;
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
+
+                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
+
+                        layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
+                        layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
+
+                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
+                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
+                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                    }
+                } break;
             case LLM_ARCH_GROK:
                 {
                     if (n_expert == 0) {
                         throw std::runtime_error("Grok model cannot have zero experts");
                     }
 
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
+                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
-                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
 
-                        if (layer.ffn_gate_exps) {
-                            layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
-                            layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
-                        } else {
-                            // merge split expert into a single tensor for compatibility with older models
-                            // requires disabling mmap
-                            use_mmap_buffer = false;
-
-                            lm_ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type;
-                            lm_ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type;
-                            lm_ggml_type type_up   = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, 0).c_str())->type;
-
-                            layer.ffn_gate_exps = lm_ggml_new_tensor_3d(ctx_split, type_gate, n_embd,   n_ff, n_expert);
-                            layer.ffn_down_exps = lm_ggml_new_tensor_3d(ctx_split, type_down,   n_ff, n_embd, n_expert);
-                            layer.ffn_up_exps   = lm_ggml_new_tensor_3d(ctx_split, type_up,   n_embd,   n_ff, n_expert);
-
-                            lm_ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str());
-                            lm_ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str());
-                            lm_ggml_set_name(layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i).c_str());
-
-                            for (uint32_t x = 0; x < n_expert; ++x) {
-                                // the individual experts are loaded into a view of the merged tensor
-                                ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x);
-                                ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x);
-                                ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x);
-                            }
-                        }
-
-                        layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
+                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
                     }
                 } break;
             case LLM_ARCH_DBRX:
-            {
-                if (n_expert == 0) {
-                    throw std::runtime_error("DBRX model cannot have zero experts");
-                }
-
-                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                // output
                 {
-                    model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                }
+                    if (n_expert == 0) {
+                        throw std::runtime_error("DBRX model cannot have zero experts");
+                    }
 
-                for (int i = 0; i < n_layer; ++i) {
-                    lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                    lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
-                    auto & layer = model.layers[i];
+                    // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
-                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
 
-                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                    layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                    layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                    layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
-                    layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert});
-                    layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert});
-                    layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert});
-                }
-            } break;
+                        layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                    }
+                } break;
             case LLM_ARCH_BAICHUAN:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
                     {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_FALCON:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
                     {
-                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
 
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         if (!model.output) {
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
+                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
                         }
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_STARCODER:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, 0);
 
                     // output
                     {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                        model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         if (!model.output) {
                             // needs to be on GPU
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                         }
 
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
 
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff});
-                        layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff});
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff}, 0);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_BERT:
             case LLM_ARCH_NOMIC_BERT:
                 {
-                    model.tok_embd     = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab});
-                    model.type_embd    = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});
+                    model.tok_embd     = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
+                    model.type_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0);
 
                     if (model.arch == LLM_ARCH_BERT) {
-                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, n_ctx_train});
+                        model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, n_ctx_train}, 0);
+
+                        model.cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        model.cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {1},         llama_model_loader::TENSOR_NOT_REQUIRED);
                     }
 
-                    model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
-                    model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});
+                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
                         if (model.arch == LLM_ARCH_BERT) {
-                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                            layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd});
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd}, 0);
 
-                            layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                            layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa});
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa}, 0);
 
-                            layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                            layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa});
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa}, 0);
                         } else {
-                            layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                            layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
                         }
 
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd});
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
-                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
+                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd});
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd}, 0);
 
                         if (model.arch == LLM_ARCH_BERT) {
-                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
-                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff});
-                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
+                            layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, 0);
+                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
                         } else {
-                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                         }
 
-                        layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
-                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd});
+                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
                     }
                 } break;
             case LLM_ARCH_JINA_BERT_V2:
                 {
-                    model.tok_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}); // word_embeddings
-                    model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); // token_type_embeddings
+                    model.tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0); // word_embeddings
+                    model.type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0); // token_type_embeddings
 
-                    model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
-                    model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}); //LayerNorm bias
+                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
+                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0); //LayerNorm bias
 
+                    model.cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    model.cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {1},         llama_model_loader::TENSOR_NOT_REQUIRED);
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i]; // JinaBertLayer
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);
 
-                        layer.attn_q_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm   = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa});
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa}, 0);
 
-                        layer.attn_k_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm   = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa});
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa}, 0);
 
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}); //output_dens
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0); //output_dens
 
-                        layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm
-                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd});
+                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm
+                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd}, 0);
 
-                        layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd});
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
 
-                        layer.layer_out_norm   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
-                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd});
+                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd}, 0);
                     }
                 } break;
             case LLM_ARCH_BLOOM:
                 {
-                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
-                    model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
-                    model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});
+                    model.tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
+                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
 
                     // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd}, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa}, 0);
 
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd});
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0);
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias",   i), {n_embd});
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias",   i), {n_embd}, 0);
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd});
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias",   i), {n_ff});
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias",   i), {n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_MPT:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                     // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        if (!model.output) {
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
-                        }
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    if (!model.output) {
+                        model.output    = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.attn_q_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm   = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.attn_k_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm   = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         // AWQ ScaleActivation layer
-                        layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
                     }
                 } break;
             case LLM_ARCH_STABLELM:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm =   ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
+                        layer.attn_norm =   create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
                         // optional bias tensors, present in Stable LM 2 1.6B
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         // optional q and k layernorms, present in StableLM 2 12B
-                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head},    llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head},    llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_QWEN:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd*3});
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd*3}, 0);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff/2});
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff/2}, 0);
                     }
                 } break;
             case LLM_ARCH_QWEN2:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
                         // optional bias tensors
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd});
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa});
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa});
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_QWEN2MOE:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
                         // optional bias tensors
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd});
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa});
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa});
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
 
-                        LM_GGML_ASSERT(n_expert      > 0);
-                        LM_GGML_ASSERT(n_expert_used > 0);
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0 for QWEN2MOE");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE");
+                        }
 
                         // MoE branch
                         const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
 
-                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
-                        layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert});
-                        layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
 
                         // Shared expert branch
                         const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
 
-                        layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
-                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp});
-                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd});
-                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp});
+                        layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp}, 0);
+                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd}, 0);
+                        layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp}, 0);
                     }
                 } break;
             case LLM_ARCH_PHI2:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                        model.output_b      = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab});
-                    }
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    model.output_b      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         if (layer.wqkv == nullptr) {
-                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
-                            layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);
 
-                            layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
-                            layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa});
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa}, 0);
 
-                            layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
-                            layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa});
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa}, 0);
                         }
 
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_PHI3:
                 {
                     const int64_t n_embd_head = n_embd / n_head;
 
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab });
-                    }
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd });
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd });
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd });
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
 
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd });
-                        layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff });
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
 
-                        layer.rope_long  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
                     }
                 } break;
             case LLM_ARCH_PLAMO:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_GPT2:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, 0);
 
                     // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
 
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_CODESHELL:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
 
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff});
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_ORION:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-                    for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
+                    for (int i = 0; i < n_layer; ++i) {
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_INTERNLM2:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        // layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_GEMMA:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                     }
                 } break;
             case LLM_ARCH_GEMMA2:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
-                        layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                     }
                 } break;
             case LLM_ARCH_STARCODER2:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
 
+                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
                         // optional bias tensors
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd});
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa});
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa});
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
 
                         // optional bias tensors
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP ,  "bias", i), {  n_ff});
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP ,  "bias", i), {  n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_MAMBA:
@@ -7624,246 +8387,252 @@ static bool llm_load_tensors(
                     const int64_t dt_rank = hparams.ssm_dt_rank;
 
                     // only an expansion factor of 2 is supported for now
-                    LM_GGML_ASSERT(2 * n_embd == d_inner);
+                    if (2 * n_embd != d_inner) {
+                        throw std::runtime_error("only an expansion factor of 2 is supported for now");
+                    }
 
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
 
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed, duplicated to allow offloading
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
+                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed, duplicated to allow offloading
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
                         // norm
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
+                        layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0);
 
-                        layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
-                        layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
+                        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0);
+                        layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0);
 
-                        layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
+                        layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0);
 
-                        layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
-                        layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
+                        layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0);
+                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0);
 
                         // no "weight" suffix for these
-                        layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
-                        layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});
+                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
+                        layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
 
                         // out_proj
-                        layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
+                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
                     }
                 } break;
             case LLM_ARCH_XVERSE:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
-                    for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
+                    for (int i = 0; i < n_layer; ++i) {
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_COMMAND_R:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        // init output from the input tok embed
-                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    // init output from the input tok embed
+                    model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
                         if (n_layer >= 64){
-                            layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head});
-                            layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv});
+                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
+                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
                         }
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_OLMO:  // adapted from LLM_ARCH_LLAMA with norm params removed
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
+                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+                        auto & layer = model.layers[i];
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_OLMOE:
+                {
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
+                    for (int i = 0; i < n_layer; ++i) {
                         auto & layer = model.layers[i];
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0");
+                        }
+
+                        // MoE branch
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
                     }
                 } break;
             case LLM_ARCH_OPENELM:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        // init output from the input tok embed
-                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    // init output from the input tok embed
+                    model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
 
                     for (int i = 0; i < n_layer; ++i) {
                         const int64_t n_head      =   hparams.n_head(i);
                         const int64_t n_head_qkv  = 2*hparams.n_head_kv(i) + n_head;
                         const int64_t n_ff        =   hparams.n_ff(i);
 
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k});
-                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
-                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_GPTNEOX:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
 
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_ARCTIC:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd});
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
-                        layer.ffn_norm_exps = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd});
-                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, false);
-                        layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
-                        layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, false);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
                     }
                 } break;
             case LLM_ARCH_DEEPSEEK2:
@@ -7879,340 +8648,403 @@ static bool llm_load_tensors(
                     const int64_t n_ff_exp        = hparams.n_ff_exp;
                     const int64_t n_expert_shared = hparams.n_expert_shared;
 
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
                         if (!is_lite) {
-                            layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank});
+                            layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
                         }
 
-                        layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank});
+                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
 
                         if (!is_lite) {
-                            layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
-                            layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k});
+                            layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
+                            layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
                         } else {
-                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
                         }
 
-                        layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
-                        layer.wkv_b     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
-                        layer.wo        = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd});
+                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
+                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
+                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
                         if (i < (int) hparams.n_layer_dense_lead) {
-                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                            layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                            layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                         } else {
-                            layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
 
-                            LM_GGML_ASSERT(n_expert      > 0);
-                            LM_GGML_ASSERT(n_expert_used > 0);
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
 
                             // MoE branch
-                            layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
-                            layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert});
-                            layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
 
                             // Shared expert branch
-                            layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared});
-                            layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd});
-                            layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared});
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
                         }
                     }
                 } break;
             case LLM_ARCH_BITNET:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    }
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd});
-                        layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
-
-                        layer.wq       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1});
-                        layer.wk       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1});
-                        layer.wv       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1});
-                        layer.wo       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1});
-
-                        layer.ffn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd});
-                        layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
-
-                        layer.ffn_gate       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
-                        layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1});
-                        layer.ffn_down       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1});
-                        layer.ffn_up         = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_scale   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1});
+                        layer.attn_norm     = create_tensor(tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd}, 0);
+                        layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq       = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.wk       = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.wv       = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.wo       = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm     = create_tensor(tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd}, 0);
+                        layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0);
+
+                        layer.ffn_gate       = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down       = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_up         = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_scale   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
                     }
                 } break;
             case LLM_ARCH_T5:
                 {
                     const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
 
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm     = ml.create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm     = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0);
 
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
+                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm_enc  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd});
-                        layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
 
-                        layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up_enc   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
 
-                        layer.attn_norm  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd});
-                        layer.attn_rel_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_norm  = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
 
-                        layer.attn_norm_cross  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd});
+                        layer.attn_norm_cross  = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd}, 0);
                         // this tensor seems to be unused in HF transformers implementation
-                        layer.attn_rel_b_cross = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wq_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wk_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+                        layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_T5ENCODER:
                 {
                     const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
 
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
+                    model.output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm_enc  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd});
-                        layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
 
-                        layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up_enc   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_JAIS:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
-                    // Output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    // output
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
 
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_gate   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff});
+                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff}, 0);
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
                     }
                 } break;
             case LLM_ARCH_CHATGLM:
                 {
-                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
+                    model.tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + (hparams.n_embd_head_k << 2)});
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
 
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2});
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2}, 0);
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
                     }
                 } break;
             case LLM_ARCH_NEMOTRON:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,   tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split,  tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
-                    }
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
                         // optional bias tensors
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
 
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
 
                         // optional MLP bias
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
                     }
                 } break;
             case LLM_ARCH_EXAONE:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM,   "weight", i), {n_embd}, 0);
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN,   "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,     "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_RWKV6:
+                {
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // Block 0, LN0
+                    model.tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+
+                    // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+                    const int head_size = hparams.wkv_head_size;
+                    const int attn_hidden_size = n_embd;
+                    const int ffn_size = hparams.n_ff_arr[0];
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, 0);
+
+                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
+                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
+
+                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
+
+                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
+                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
+                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
+                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
+                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
+
+                        layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
+                        layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
+                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+                        layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
+
+                        layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
+                        layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
+                        layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0);
+                    }
+
+                } break;
+            case LLM_ARCH_CHAMELEON:
+                {
+                 model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                 // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
-
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
+                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i),  {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i),  {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
+
+        if (n_moved_tensors > 0) {
+            LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n",
+                __func__, first_moved_tensor->name, lm_ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1,
+                lm_ggml_backend_buft_name(first_moved_from_buft), lm_ggml_backend_buft_name(first_moved_to_buft));
+        }
     }
 
     ml.done_getting_tensors();
@@ -8225,67 +9057,54 @@ static bool llm_load_tensors(
     ctx_bufs.reserve(ctx_map.size());
 
     // Ensure we have enough capacity for the maximum backend buffer we will potentially create
-    size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
+    const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
     model.bufs.reserve(n_max_backend_buffer);
 
     for (auto & it : ctx_map) {
         lm_ggml_backend_buffer_type_t buft = it.first;
         lm_ggml_context * ctx              = it.second;
 
+        // skip contexts without tensors
+        if (lm_ggml_get_first_tensor(ctx) == nullptr) {
+            continue;
+        }
+
         llama_buf_map bufs;
         bufs.reserve(n_max_backend_buffer);
 
-        // only the mmap region containing the tensors in the model is mapped to the backend buffer
-        // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
-        // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
-        if (ml.use_mmap && use_mmap_buffer && buft == llama_default_buffer_type_cpu(true)) {
+        // check if it is possible to use buffer_from_host_ptr with this buffer type
+        lm_ggml_backend_dev_t dev = lm_ggml_backend_buft_get_device(buft);
+        lm_ggml_backend_dev_props props;
+        lm_ggml_backend_dev_get_props(dev, &props);
+        bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
+        bool is_default_buft = buft == lm_ggml_backend_dev_buffer_type(dev);
+
+        if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
             for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
+                // only the mmap region containing the tensors in the model is mapped to the backend buffer
+                // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
+                // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
                 void * addr = nullptr;
-                size_t first, last;
+                size_t first, last; // NOLINT
                 ml.get_mapping_range(&first, &last, &addr, idx, ctx);
                 if (first >= last) {
                     continue;
                 }
-                lm_ggml_backend_buffer_t buf = lm_ggml_backend_cpu_buffer_from_ptr((char *) addr + first, last - first);
-                if (buf == nullptr) {
-                    throw std::runtime_error("unable to allocate backend CPU buffer");
-                }
-                model.bufs.push_back(buf);
-                bufs.emplace(idx, buf);
-#ifdef LM_GGML_USE_CUDA
-                if (n_layer >= n_gpu_layers) {
-                    lm_ggml_backend_cuda_register_host_buffer(
-                        lm_ggml_backend_buffer_get_base(buf),
-                        lm_ggml_backend_buffer_get_size(buf));
-                }
-#endif
-            }
-        }
-#ifdef LM_GGML_USE_METAL
-        else if (ml.use_mmap && use_mmap_buffer && buft == lm_ggml_backend_metal_buffer_type()) {
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
                 const size_t max_size = lm_ggml_get_max_tensor_size(ctx);
-                void * addr = nullptr;
-                size_t first, last;
-                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
-                if (first >= last) {
-                    continue;
-                }
-                lm_ggml_backend_buffer_t buf = lm_ggml_backend_metal_buffer_from_ptr((char *) addr + first, last - first, max_size);
+                lm_ggml_backend_buffer_t buf = lm_ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
                 if (buf == nullptr) {
-                    throw std::runtime_error("unable to allocate backend metal buffer");
+                    throw std::runtime_error(format("unable to allocate %s buffer", lm_ggml_backend_buft_name(buft)));
                 }
-                model.bufs.push_back(buf);
+                model.bufs.emplace_back(buf);
                 bufs.emplace(idx, buf);
             }
         }
-#endif
         else {
             lm_ggml_backend_buffer_t buf = lm_ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
             if (buf == nullptr) {
-                throw std::runtime_error("unable to allocate backend buffer");
+                throw std::runtime_error(format("unable to allocate %s buffer", lm_ggml_backend_buft_name(buft)));
             }
-            model.bufs.push_back(buf);
+            model.bufs.emplace_back(buf);
             if (use_mlock && lm_ggml_backend_buffer_is_host(buf)) {
                 model.mlock_bufs.emplace_back(new llama_mlock);
                 auto & mlock_buf = model.mlock_bufs.back();
@@ -8303,7 +9122,7 @@ static bool llm_load_tensors(
 
         for (auto & buf : bufs) {
             // indicate that this buffer contains weights
-            // this is used by lm_ggml_backend_sched to improve op scheduling -> ops that use a weight are preferably scheduled to the backend that contains the weight
+            // this is used by lm_ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight
             lm_ggml_backend_buffer_set_usage(buf.second, LM_GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
         }
 
@@ -8315,7 +9134,7 @@ static bool llm_load_tensors(
 
         LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
         if (n_gpu_layers > (int) hparams.n_layer) {
-            LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__);
+            LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__);
         }
 
         const int max_backend_supported_layers = hparams.n_layer + 1;
@@ -8324,14 +9143,14 @@ static bool llm_load_tensors(
         LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
     }
 
-    // print memory requirements
-    for (lm_ggml_backend_buffer_t buf : model.bufs) {
-        LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, lm_ggml_backend_buffer_name(buf), lm_ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
+    // print memory requirements per buffer type
+    for (auto & buf : model.bufs) {
+        LLAMA_LOG_INFO("%s: %10s model buffer size = %8.2f MiB\n", __func__, lm_ggml_backend_buffer_name(buf.get()), lm_ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
     }
 
     // populate tensors_by_name
-    for (lm_ggml_context * ctx : model.ctxs) {
-        for (auto * cur = lm_ggml_get_first_tensor(ctx); cur != NULL; cur = lm_ggml_get_next_tensor(ctx, cur)) {
+    for (auto & ctx : model.ctxs) {
+        for (auto * cur = lm_ggml_get_first_tensor(ctx.get()); cur != NULL; cur = lm_ggml_get_next_tensor(ctx.get(), cur)) {
             model.tensors_by_name.emplace_back(lm_ggml_get_name(cur), cur);
         }
     }
@@ -8351,14 +9170,13 @@ static bool llm_load_tensors(
         }
     }
 
-    // loading time will be recalculate after the first eval, so
-    // we take page faults deferred by mmap() into consideration
-    model.t_load_us = lm_ggml_time_us() - model.t_start_us;
     return true;
 }
 
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
 static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
+    model.t_start_us = lm_ggml_time_us();
+
     try {
         llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
 
@@ -8392,23 +9210,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
             return 0;
         }
 
-#ifdef LM_GGML_USE_KOMPUTE
-        if (params.n_gpu_layers > 0 && (
-            !(model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON)
-            || !(
-                model.ftype == LLAMA_FTYPE_ALL_F32 ||
-                model.ftype == LLAMA_FTYPE_MOSTLY_F16 ||
-                model.ftype == LLAMA_FTYPE_MOSTLY_BF16 ||
-                model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
-                model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1
-            )
-        )) {
-            // TODO(cebtenzzre): propagate this error outside of llama_load_model_from_file
-            LLAMA_LOG_WARN("%s: disabling Kompute due to unsupported model arch or quantization\n", __func__);
-            params.n_gpu_layers = 0;
-        }
-#endif
-
         if (!llm_load_tensors(
             ml, model, params.n_gpu_layers, params.split_mode,  params.main_gpu, params.tensor_split, params.use_mlock,
             params.progress_callback, params.progress_callback_user_data
@@ -8420,6 +9221,10 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         return -1;
     }
 
+    // loading time will be recalculate after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    model.t_load_us = lm_ggml_time_us() - model.t_start_us;
+
     return 0;
 }
 
@@ -8470,6 +9275,11 @@ static struct lm_ggml_tensor * llm_build_inp_embd(
         lm_ggml_set_input(lctx.inp_embd);
     }
 
+    // For Granite architecture
+    if (hparams.f_embedding_scale != 0.0f) {
+        inpL = lm_ggml_scale(ctx, inpL, hparams.f_embedding_scale);
+    }
+
     cb(inpL, "inp_embd", -1);
 
     return inpL;
@@ -8494,8 +9304,7 @@ static void llm_build_kv_store(
 
     LM_GGML_ASSERT(kv.size == n_ctx);
 
-    struct lm_ggml_tensor * k_cache_view = lm_ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
-            (lm_ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
+    struct lm_ggml_tensor * k_cache_view = lm_ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, lm_ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa)*kv_head);
     cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
@@ -8506,8 +9315,7 @@ static void llm_build_kv_store(
     struct lm_ggml_tensor * v_cache_view = nullptr;
 
     if (cparams.flash_attn) {
-        v_cache_view = lm_ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
-                (kv_head)*lm_ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa));
+        v_cache_view = lm_ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, lm_ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)*kv_head);
     } else {
         // note: the V cache is transposed when not using flash attention
         v_cache_view = lm_ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
@@ -8888,20 +9696,16 @@ static struct lm_ggml_tensor * llm_build_kqv(
         cur = lm_ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
-            lm_ggml_flash_attn_ext_set_prec(cur, LM_GGML_PREC_F32);
-        }
+        lm_ggml_flash_attn_ext_set_prec(cur, LM_GGML_PREC_F32);
 
         cur = lm_ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
-            // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-            lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32);
-        }
+        // note: this op tends to require high floating point range
+        //       while for some models F16 is enough, for others it is not, so we default to F32 here
+        lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32);
 
         if (model.arch == LLM_ARCH_GROK) {
             // need to do the following:
@@ -8910,9 +9714,6 @@ static struct lm_ggml_tensor * llm_build_kqv(
             // kq = 30 * tanh(kq / 30)
             // before the softmax below
 
-            //try from phi2
-            //lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32);
-
             kq = lm_ggml_tanh(ctx, lm_ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
             kq = lm_ggml_scale(ctx, kq, 30);
         }
@@ -8994,8 +9795,7 @@ static struct lm_ggml_tensor * llm_build_kv(
 
     struct lm_ggml_tensor * cur;
 
-    cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
-            q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
+    cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
 
     return cur;
@@ -9023,7 +9823,7 @@ static struct lm_ggml_tensor * llm_build_copy_mask_state(
     // FIXME: zero-out NANs?
     states = lm_ggml_mul(ctx, states, state_mask);
 
-    // copy states which won't be changed further (between n_seqs and n_rs)
+    // copy states which won't be changed further (between n_seqs and n_kv)
     lm_ggml_build_forward_expand(graph,
         lm_ggml_cpy(ctx,
             lm_ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*lm_ggml_element_size(states)),
@@ -9169,12 +9969,177 @@ static struct lm_ggml_tensor * llm_build_mamba(
     return cur;
 }
 
+static struct lm_ggml_tensor * llm_build_rwkv6_time_mix(
+        struct llama_context & lctx,
+        struct lm_ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct lm_ggml_tensor * cur,
+        struct lm_ggml_tensor * x_prev,
+        struct lm_ggml_tensor ** wkv_state) {
+    size_t n_embd       = cur->ne[0];
+    size_t n_seq_tokens = cur->ne[1];
+    size_t n_seqs       = cur->ne[2];
+
+    size_t head_size  = layer->time_mix_first->ne[0];
+    size_t head_count = layer->time_mix_first->ne[1];
+
+    size_t n_tokens = n_seqs * n_seq_tokens;
+
+    struct lm_ggml_tensor * sx = lm_ggml_sub(ctx, x_prev, cur);
+
+    sx  = lm_ggml_reshape_2d(ctx, sx,  n_embd, n_tokens);
+    cur = lm_ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+
+    struct lm_ggml_tensor * xxx = lm_ggml_add(ctx, lm_ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
+
+    xxx = lm_ggml_reshape_4d(
+        ctx,
+        lm_ggml_tanh(
+            ctx,
+            lm_ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
+        ),
+        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+    );
+
+    xxx = lm_ggml_cont(ctx, lm_ggml_permute(ctx, xxx, 0, 1, 3, 2));
+
+    xxx = lm_ggml_mul_mat(
+        ctx,
+        lm_ggml_reshape_4d(
+            ctx,
+            layer->time_mix_w2,
+            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
+        ),
+        xxx
+    );
+
+    struct lm_ggml_tensor *mw = lm_ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+    struct lm_ggml_tensor *mk = lm_ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+    struct lm_ggml_tensor *mv = lm_ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+    struct lm_ggml_tensor *mr = lm_ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+    struct lm_ggml_tensor *mg = lm_ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+
+    struct lm_ggml_tensor * xw = lm_ggml_add(
+        ctx,
+        lm_ggml_mul(
+            ctx,
+            lm_ggml_add(ctx, mw, layer->time_mix_lerp_w),
+            sx
+        ),
+        cur
+    );
+
+    struct lm_ggml_tensor * xk = lm_ggml_add(
+        ctx,
+        lm_ggml_mul(
+            ctx,
+            lm_ggml_add(ctx, mk, layer->time_mix_lerp_k),
+            sx
+        ),
+        cur
+    );
+
+    struct lm_ggml_tensor * xv = lm_ggml_add(
+        ctx,
+        lm_ggml_mul(
+            ctx,
+            lm_ggml_add(ctx, mv, layer->time_mix_lerp_v),
+            sx
+        ),
+        cur
+    );
+
+    struct lm_ggml_tensor * xr = lm_ggml_add(
+        ctx,
+        lm_ggml_mul(
+            ctx,
+            lm_ggml_add(ctx, mr, layer->time_mix_lerp_r),
+            sx
+        ),
+        cur
+    );
+
+    struct lm_ggml_tensor * xg = lm_ggml_add(
+        ctx,
+        lm_ggml_mul(
+            ctx,
+            lm_ggml_add(ctx, mg, layer->time_mix_lerp_g),
+            sx
+        ),
+        cur
+    );
+
+    struct lm_ggml_tensor * r = lm_ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1,         head_count, n_tokens);
+    struct lm_ggml_tensor * k = lm_ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk), 1,         head_size, head_count, n_tokens);
+    struct lm_ggml_tensor * v = lm_ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv), head_size, 1,         head_count, n_tokens);
+    struct lm_ggml_tensor * g = lm_ggml_silu(
+        ctx,
+        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
+    );
+
+    struct lm_ggml_tensor * w = lm_ggml_mul_mat(
+        ctx,
+        layer->time_mix_decay_w2,
+        lm_ggml_tanh(
+            ctx,
+            lm_ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
+        )
+    );
+
+    w = lm_ggml_add(ctx, w, lm_ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
+    w = lm_ggml_exp(ctx, lm_ggml_neg(ctx, lm_ggml_exp(ctx, w)));
+    w = lm_ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+
+    k = lm_ggml_transpose(ctx, k);
+    v = lm_ggml_transpose(ctx, v);
+    r = lm_ggml_transpose(ctx, r);
+
+    struct lm_ggml_tensor * wkv_output = lm_ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    cur = lm_ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
+    *wkv_state = lm_ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
+
+    // group norm with head_count groups
+    cur = lm_ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
+    cur = lm_ggml_norm(ctx, cur, 64e-5f);
+
+    // Convert back to regular vectors.
+    cur = lm_ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+    cur = lm_ggml_add(ctx, lm_ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+
+    cur = lm_ggml_mul(ctx, cur, g);
+    cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
+
+    return lm_ggml_reshape_3d(ctx, cur, n_embd, n_seq_tokens, n_seqs);
+}
+
+static struct lm_ggml_tensor * llm_build_rwkv6_channel_mix(
+        struct llama_context & lctx,
+        struct lm_ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct lm_ggml_tensor * cur,
+        struct lm_ggml_tensor * x_prev) {
+    struct lm_ggml_tensor * sx = lm_ggml_sub(ctx, x_prev, cur);
+    struct lm_ggml_tensor * xk = lm_ggml_add(ctx, lm_ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
+    struct lm_ggml_tensor * xr = lm_ggml_add(ctx, lm_ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
+
+    struct lm_ggml_tensor * r = lm_ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr));
+    struct lm_ggml_tensor * k = lm_ggml_sqr(
+        ctx,
+        lm_ggml_relu(
+            ctx,
+            llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
+        )
+    );
+
+    return lm_ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
+}
+
 struct llm_build_context {
     const llama_model    & model;
           llama_context  & lctx;
     const llama_hparams  & hparams;
     const llama_cparams  & cparams;
-    const llama_ubatch   & batch;
+    const llama_ubatch   & ubatch;
     const llama_kv_cache & kv_self;
 
     const int64_t n_embd;
@@ -9220,14 +10185,14 @@ struct llm_build_context {
     // TODO: consider making the entire interface noexcept
     llm_build_context(
         llama_context  & lctx,
-    const llama_ubatch & batch,
+    const llama_ubatch & ubatch,
     const llm_build_cb & cb,
                   bool   worst_case) :
         model            (lctx.model),
         lctx             (lctx),
         hparams          (model.hparams),
         cparams          (lctx.cparams),
-        batch            (batch),
+        ubatch           (ubatch),
         kv_self          (lctx.kv_self),
         n_embd           (hparams.n_embd),
         n_layer          (hparams.n_layer),
@@ -9249,7 +10214,7 @@ struct llm_build_context {
         beta_slow        (cparams.yarn_beta_slow),
         norm_eps         (hparams.f_norm_eps),
         norm_rms_eps     (hparams.f_norm_rms_eps),
-        n_tokens         (batch.n_tokens),
+        n_tokens         (ubatch.n_tokens),
         n_kv             (worst_case ? kv_self.size : kv_self.n),
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
         n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
@@ -9290,10 +10255,8 @@ struct llm_build_context {
     }
 
     void free() {
-        if (ctx0) {
-            lm_ggml_free(ctx0);
-            ctx0 = nullptr;
-        }
+        lm_ggml_free(ctx0);
+        ctx0 = nullptr;
     }
 
     struct lm_ggml_cgraph * build_k_shift() {
@@ -9309,17 +10272,36 @@ struct llm_build_context {
             const int64_t n_head_kv = hparams.n_head_kv(il);
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             struct lm_ggml_tensor * rope_factors = build_rope_factors(il);
-            struct lm_ggml_tensor * tmp =
+            struct lm_ggml_tensor * k =
+                lm_ggml_view_3d(ctx0, kv_self.k_l[il],
+                    n_embd_head_k, n_head_kv, n_ctx,
+                    lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                    lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                    0);
+
+            struct lm_ggml_tensor * tmp;
+            if (lm_ggml_is_quantized(k->type)) {
+                // dequantize to f32 -> RoPE -> quantize back
+                tmp = lm_ggml_cast(ctx0, k, LM_GGML_TYPE_F32);
+                cb(tmp, "K_f32", il);
+                for (auto & backend : lctx.backends) {
+                    // Figure out which backend KV cache belongs to
+                    if (lm_ggml_backend_supports_buft(backend.get(), lm_ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) {
+                        lm_ggml_backend_sched_set_tensor_backend(lctx.sched.get(), tmp, backend.get());
+                        break;
+                    }
+                }
+                tmp = lm_ggml_rope_ext_inplace(ctx0, tmp,
+                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(tmp, "K_shifted_f32", il);
+                tmp = lm_ggml_cpy(ctx0, tmp, k);
+            } else {
                 // we rotate only the first n_rot dimensions
-                lm_ggml_rope_ext_inplace(ctx0,
-                        lm_ggml_view_3d(ctx0, kv_self.k_l[il],
-                            n_embd_head_k, n_head_kv, n_ctx,
-                            lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                            lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                            0),
+                tmp = lm_ggml_rope_ext_inplace(ctx0, k,
                         lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
-
+            }
             cb(tmp, "K_shifted", il);
             lm_ggml_build_forward_expand(gf, tmp);
         }
@@ -9477,8 +10459,8 @@ struct llm_build_context {
     struct lm_ggml_cgraph * append_pooling(struct lm_ggml_cgraph * gf) {
         // find result_norm tensor for input
         struct lm_ggml_tensor * inp = nullptr;
-        for (int i = gf->n_nodes - 1; i >= 0; --i) {
-            inp = gf->nodes[i];
+        for (int i = lm_ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+            inp = lm_ggml_graph_node(gf, i);
             if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
                 break;
             } else {
@@ -9490,6 +10472,10 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
 
         switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    cur = inp;
+                } break;
             case LLAMA_POOLING_TYPE_MEAN:
                 {
                     struct lm_ggml_tensor * inp_mean = build_inp_mean();
@@ -9501,9 +10487,26 @@ struct llm_build_context {
                     struct lm_ggml_tensor * inp_cls = build_inp_cls();
                     cur = lm_ggml_get_rows(ctx0, inp, inp_cls);
                 } break;
-            case LLAMA_POOLING_TYPE_NONE:
+            case LLAMA_POOLING_TYPE_RANK:
                 {
-                    cur = inp;
+                    struct lm_ggml_tensor * inp_cls = build_inp_cls();
+                    inp = lm_ggml_get_rows(ctx0, inp, inp_cls);
+
+                    // classification head
+                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                    LM_GGML_ASSERT(model.cls       != nullptr);
+                    LM_GGML_ASSERT(model.cls_b     != nullptr);
+
+                    cur = lm_ggml_add (ctx0, lm_ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
+                    cur = lm_ggml_tanh(ctx0, cur);
+
+                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                    if (model.cls_out) {
+                        LM_GGML_ASSERT(model.cls_out_b != nullptr);
+
+                        cur = lm_ggml_add (ctx0, lm_ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
+                    }
                 } break;
             default:
                 {
@@ -9578,7 +10581,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -9586,6 +10589,7 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         for (int il = 0; il < n_layer; ++il) {
             struct lm_ggml_tensor * inpSA = inpL;
 
@@ -9638,7 +10642,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9649,6 +10653,11 @@ struct llm_build_context {
                 inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = lm_ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
             struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
@@ -9685,6 +10694,11 @@ struct llm_build_context {
                 cb(cur, "ffn_moe_out", il);
             }
 
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = lm_ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -9704,6 +10718,12 @@ struct llm_build_context {
 
         // lm_head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        // For Granite architecture
+        if (hparams.f_logit_scale) {
+            cur = lm_ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        }
+
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -9721,7 +10741,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = model.type == MODEL_7B ? build_inp_pos() : nullptr;
@@ -9836,7 +10856,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -9940,7 +10960,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -10062,7 +11082,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // multiply by embedding_multiplier_scale of 78.38367176906169
         inpL = lm_ggml_scale(ctx0, inpL, 78.38367176906169f);
@@ -10220,7 +11240,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -10342,7 +11362,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -10445,7 +11465,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -10547,7 +11567,7 @@ struct llm_build_context {
         }
 
         // construct input embeddings (token, type, position)
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // token types are hardcoded to zero ("Sentence A")
         struct lm_ggml_tensor * type_row0 = lm_ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
@@ -10715,8 +11735,8 @@ struct llm_build_context {
             inpL = cur;
         }
 
-        // final output
         cur = inpL;
+
         cb(cur, "result_embd", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -10734,7 +11754,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -10836,7 +11856,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * pos;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -10974,7 +11994,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -11124,7 +12144,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -11237,7 +12257,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -11352,7 +12372,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -11497,7 +12517,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * ffn_output;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -11616,7 +12636,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -11744,7 +12764,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -11849,7 +12869,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * pos;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -11954,7 +12974,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -12064,7 +13084,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -12182,7 +13202,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -12296,20 +13316,164 @@ struct llm_build_context {
     struct lm_ggml_cgraph * build_minicpm() {
         struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        LM_GGML_ASSERT(n_embd_head == hparams.n_rot);
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        LM_GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        const int64_t n_embd = hparams.n_embd;
+        //TODO: if the model varies, these parameters need to be read from the model
+        const int64_t n_embd_base = 256;
+        const float scale_embd  = 12.0f;
+        const float scale_depth = 1.4f;
+
+        struct lm_ggml_tensor * cur;
+        struct lm_ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // scale the input embeddings
+        inpL = lm_ggml_scale(ctx0, inpL, scale_embd);
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        struct lm_ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct lm_ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = lm_ggml_rope_ext(
+                    ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = lm_ggml_rope_ext(
+                    ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = lm_ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // scale_res - scale the hidden states for residual connection
+            const float scale_res = scale_depth/sqrtf(float(n_layer));
+            cur = lm_ggml_scale(ctx0, cur, scale_res);
+            cb(cur, "hidden_scaled", -1);
+
+            struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // scale the hidden states for residual connection
+            cur = lm_ggml_scale(ctx0, cur, scale_res);
+            cb(cur, "hidden_scaled_ffn", -1);
+
+            cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head scaling
+        const float scale_lmhead = float(n_embd_base)/float(n_embd);
+        cur = lm_ggml_scale(ctx0, cur, scale_lmhead);
+        cb(cur, "lmhead_scaling", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        lm_ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct lm_ggml_cgraph * build_minicpm3() {
+        struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-        const int64_t n_embd = hparams.n_embd;
         //TODO: if the model varies, these parameters need to be read from the model
         const int64_t n_embd_base = 256;
         const float scale_embd  = 12.0f;
         const float scale_depth = 1.4f;
+        const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
+
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // scale the input embeddings
         inpL = lm_ggml_scale(ctx0, inpL, scale_embd);
@@ -12324,53 +13488,118 @@ struct llm_build_context {
         for (int il = 0; il < n_layer; ++il) {
             struct lm_ggml_tensor * inpSA = inpL;
 
+            struct lm_ggml_tensor * rope_factors = build_rope_factors(il);
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            // self-attention
+            // self_attention
             {
-                // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
+                struct lm_ggml_tensor * q = NULL;
+                // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
+                q = lm_ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
+                cb(q, "q", il);
 
-                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
+                q = llm_build_norm(ctx0, q, hparams,
+                        model.layers[il].attn_q_a_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(q, "q", il);
 
-                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
+                // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
+                q = lm_ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
+                cb(q, "q", il);
 
-                Qcur = lm_ggml_rope_ext(
-                    ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                struct lm_ggml_tensor * q_nope = lm_ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        lm_ggml_row_size(q->type, hparams.n_embd_head_k),
+                        lm_ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_head * n_embd_head_qk_rope, n_tokens}
+                struct lm_ggml_tensor * q_pe = lm_ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        lm_ggml_row_size(q->type, hparams.n_embd_head_k),
+                        lm_ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        lm_ggml_row_size(q->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
+                struct lm_ggml_tensor * kv_pe_compresseed = lm_ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
+
+                // split into {kv_lora_rank, n_tokens}
+                struct lm_ggml_tensor * kv_compressed = lm_ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // and {n_embd_head_qk_rope, n_tokens}
+                struct lm_ggml_tensor * k_pe = lm_ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        kv_pe_compresseed->nb[1],
+                        lm_ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
+
+                kv_compressed = lm_ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+                struct lm_ggml_tensor * kv = lm_ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                cb(kv, "kv", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                struct lm_ggml_tensor * k_nope = lm_ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        lm_ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        lm_ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
+                cb(k_nope, "k_nope", il);
+
+                // and {n_head * n_embd_head_v, n_tokens}
+                struct lm_ggml_tensor * v_states = lm_ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        lm_ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        lm_ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        lm_ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+                cb(v_states, "v_states", il);
+
+                v_states = lm_ggml_cont(ctx0, v_states);
+                cb(v_states, "v_states", il);
+
+                v_states = lm_ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                    lm_ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                    0);
+                cb(v_states, "v_states", il);
+
+                q_pe = lm_ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                q_pe = lm_ggml_rope_ext(
+                    ctx0, q_pe, inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
-                cb(Qcur, "Qcur", il);
+                cb(q_pe, "q_pe", il);
 
-                Kcur = lm_ggml_rope_ext(
-                    ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                // shared RoPE key
+                k_pe = lm_ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                k_pe = lm_ggml_rope_ext(
+                    ctx0, k_pe, inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
-                cb(Kcur, "Kcur", il);
+                cb(k_pe, "k_pe", il);
+
+                struct lm_ggml_tensor * q_states = lm_ggml_concat(ctx0, q_nope, q_pe, 0);
+                cb(q_states, "q_states", il);
+
+                struct lm_ggml_tensor * k_states = lm_ggml_concat(ctx0, k_nope, lm_ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        model.layers[il].wo, NULL,
+                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -12383,7 +13612,7 @@ struct llm_build_context {
             // scale_res - scale the hidden states for residual connection
             const float scale_res = scale_depth/sqrtf(float(n_layer));
             cur = lm_ggml_scale(ctx0, cur, scale_res);
-            cb(cur, "hidden_scaled", -1);
+            cb(cur, "hidden_scaled", il);
 
             struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
@@ -12406,7 +13635,7 @@ struct llm_build_context {
 
             // scale the hidden states for residual connection
             cur = lm_ggml_scale(ctx0, cur, scale_res);
-            cb(cur, "hidden_scaled_ffn", -1);
+            cb(cur, "hidden_scaled_ffn", il);
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
             cur = lctx.cvec.apply_to(ctx0, cur, il);
@@ -12445,7 +13674,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         inpL = lm_ggml_scale(ctx0, inpL, sqrtf(n_embd));
         cb(inpL, "inp_scaled", -1);
@@ -12553,7 +13782,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         inpL = lm_ggml_scale(ctx0, inpL, sqrtf(n_embd));
         cb(inpL, "inp_scaled", -1);
@@ -12691,7 +13920,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -12807,7 +14036,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * inpL;
 
         // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         struct lm_ggml_tensor * state_copy = build_inp_s_copy();
         struct lm_ggml_tensor * state_mask = build_inp_s_mask();
@@ -12819,7 +14048,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
+            cur = llm_build_mamba(ctx0, lctx, ubatch, gf, cur,
                     state_copy, state_mask,
                     kv_head, n_kv, cb, il);
 
@@ -12865,7 +14094,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -12971,9 +14200,145 @@ struct llm_build_context {
                 cb(cur, "ffn_out", il);
             }
 
-            // add together residual + FFN + self-attention
-            cur = lm_ggml_add(ctx0, cur, inpL);
-            cur = lm_ggml_add(ctx0, cur, attn_out);
+            // add together residual + FFN + self-attention
+            cur = lm_ggml_add(ctx0, cur, inpL);
+            cur = lm_ggml_add(ctx0, cur, attn_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (f_logit_scale) {
+            cur = lm_ggml_scale(ctx0, cur, f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+
+        lm_ggml_build_forward_expand(gf, cur);
+
+        return gf;
+
+    }
+
+    // ref: https://allenai.org/olmo
+    // based on the original build_llama() function, changes:
+    //   * non-parametric layer norm
+    //   * clamp qkv
+    //   * removed bias
+    //   * removed MoE
+    struct lm_ggml_cgraph * build_olmo() {
+        struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        LM_GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct lm_ggml_tensor * cur;
+        struct lm_ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct lm_ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct lm_ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    NULL, NULL,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Qcur = lm_ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Kcur = lm_ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Vcur = lm_ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = lm_ggml_rope_ext(
+                    ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = lm_ggml_rope_ext(
+                    ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, nullptr,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = lm_ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    NULL, NULL,
+                    LLM_NORM, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
             cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
@@ -12984,32 +14349,24 @@ struct llm_build_context {
         cur = inpL;
 
         cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
+                NULL, NULL,
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
         // lm_head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        if (f_logit_scale) {
-            cur = lm_ggml_scale(ctx0, cur, f_logit_scale);
-        }
-
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
 
         return gf;
-
     }
 
-    // ref: https://allenai.org/olmo
-    // based on the original build_llama() function, changes:
-    //   * non-parametric layer norm
-    //   * clamp qkv
+    // based on the build_qwen2moe() function, changes:
+    //   * removed shared experts
     //   * removed bias
-    //   * removed MoE
-    struct lm_ggml_cgraph * build_olmo() {
+    //   * added q, k norm
+    struct lm_ggml_cgraph * build_olmoe() {
         struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
@@ -13022,7 +14379,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -13035,50 +14392,49 @@ struct llm_build_context {
 
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
-                    NULL, NULL,
-                    LLM_NORM, cb, il);
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            // self-attention
+            // self_attention
             {
                 // compute Q and K and RoPE them
                 struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    Qcur = lm_ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(Qcur, "Qcur", il);
-                }
 
                 struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    Kcur = lm_ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(Kcur, "Kcur", il);
-                }
 
                 struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    Vcur = lm_ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(Vcur, "Vcur", il);
-                }
+
+                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = lm_ggml_rope_ext(
-                    ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    ctx0, Qcur, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
-                cb(Qcur, "Qcur", il);
+                cb(Qcur, "Qcur_rope", il);
 
                 Kcur = lm_ggml_rope_ext(
-                    ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    ctx0, Kcur, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
-                cb(Kcur, "Kcur", il);
+                cb(Kcur, "Kcur_rope", il);
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, nullptr,
+                        model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
@@ -13093,23 +14449,24 @@ struct llm_build_context {
             struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
-            // feed-forward network
+            // MoE branch
             cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    NULL, NULL,
-                    LLM_NORM, cb, il);
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, false,
+                    false, 0.0,
+                    cb, il);
+            cb(cur, "ffn_moe_out", il);
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
             cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
@@ -13120,8 +14477,8 @@ struct llm_build_context {
         cur = inpL;
 
         cur = llm_build_norm(ctx0, cur, hparams,
-                NULL, NULL,
-                LLM_NORM, cb, -1);
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
         // lm_head
@@ -13141,7 +14498,7 @@ struct llm_build_context {
 
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -13268,7 +14625,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -13413,7 +14770,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -13554,7 +14911,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * inpL;
 
         // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -13769,7 +15126,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -13789,7 +15146,9 @@ struct llm_build_context {
             {
                 // compute Q and K and RoPE them
                 struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                Qcur = lm_ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
+                if (model.layers[il].wq_scale) {
+                    Qcur = lm_ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
+                }
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -13798,7 +15157,9 @@ struct llm_build_context {
 
                 // B1.K
                 struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                Kcur = lm_ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
+                if (model.layers[il].wk_scale) {
+                    Kcur = lm_ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
+                }
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -13807,7 +15168,9 @@ struct llm_build_context {
 
                 // B1.V
                 struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                Vcur = lm_ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
+                if (model.layers[il].wv_scale) {
+                    Vcur = lm_ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
+                }
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -13838,7 +15201,9 @@ struct llm_build_context {
                 cb(cur, "attn_sub_norm", il);
 
                 cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-                cur = lm_ggml_mul(ctx0, cur, model.layers[il].wo_scale);
+                if (model.layers[il].wo_scale) {
+                    cur = lm_ggml_mul(ctx0, cur, model.layers[il].wo_scale);
+                }
                 if (model.layers[il].bo) {
                     cur = lm_ggml_add(ctx0, cur, model.layers[il].bo);
                 }
@@ -13875,7 +15240,9 @@ struct llm_build_context {
             cb(cur, "ffn_sub_norm", il);
 
             cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur);
-            cur = lm_ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
+            if (model.layers[il].ffn_down_scale) {
+                cur = lm_ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
+            }
             cb(cur, "ffn_down", il);
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
@@ -13893,6 +15260,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
+        // FIXME: do not use model.tok_embd directly, duplicate as model.output
         cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur);
         cb(cur, "result_output", -1);
 
@@ -13913,7 +15281,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         LM_GGML_ASSERT(lctx.is_encoding);
         struct lm_ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
@@ -14045,7 +15413,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         LM_GGML_ASSERT(!lctx.is_encoding);
         LM_GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
@@ -14247,7 +15615,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -14339,7 +15707,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -14453,7 +15821,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -14577,7 +15945,7 @@ struct llm_build_context {
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct lm_ggml_tensor * inp_pos = build_inp_pos();
@@ -14596,48 +15964,317 @@ struct llm_build_context {
 
             // self-attention
             {
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct lm_ggml_tensor * rope_factors = build_rope_factors(il);
-
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct lm_ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = lm_ggml_rope_ext(
+                    ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = lm_ggml_rope_ext(
+                    ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = lm_ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        lm_ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    lm_ggml_cgraph * build_rwkv6() {
+        lm_ggml_cgraph *gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // Token shift state dimensions should be 2 * n_emb
+        LM_GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
+
+        const int64_t n_seqs = ubatch.n_seqs;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
+        LM_GGML_ASSERT(n_seqs != 0);
+        LM_GGML_ASSERT(ubatch.equal_seqs);
+        LM_GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
+        struct lm_ggml_tensor * cur;
+        struct lm_ggml_tensor * inpL;
+        struct lm_ggml_tensor * state_copy = build_inp_s_copy();
+        struct lm_ggml_tensor * state_mask = build_inp_s_mask();
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            // (ab)using the KV cache to store the states
+            struct lm_ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.k_l[il], state_copy, state_mask,
+                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+            struct lm_ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.v_l[il], state_copy, state_mask,
+                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
+
+            cur = lm_ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            token_shift = lm_ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs);
+
+            struct lm_ggml_tensor * att_shift = lm_ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
+            struct lm_ggml_tensor * ffn_shift = lm_ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * lm_ggml_element_size(token_shift));
+
+            struct lm_ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il);
+            struct lm_ggml_tensor * x_prev = lm_ggml_concat(
+                ctx0,
+                att_shift,
+                lm_ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
+                1
+            );
+
+            cur = lm_ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            lm_ggml_build_forward_expand(gf, cur);
+            lm_ggml_build_forward_expand(
+                gf,
+                lm_ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    lm_ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * lm_ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
+            struct lm_ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il);
+            x_prev = lm_ggml_concat(
+                ctx0,
+                ffn_shift,
+                lm_ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
+                1
+            );
+            cur = lm_ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
+            lm_ggml_build_forward_expand(gf, cur);
+
+            struct lm_ggml_tensor * last_norm_att = lm_ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*lm_ggml_element_size(x_norm_att));
+            struct lm_ggml_tensor * last_norm_ffn = lm_ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*lm_ggml_element_size(x_norm_ffn));
+
+            token_shift = lm_ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
+
+            lm_ggml_build_forward_expand(
+                gf,
+                lm_ggml_cpy(
+                    ctx0,
+                    lm_ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
+                    lm_ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * lm_ggml_element_size(kv_self.k_l[il]))
+                )
+            );
+
+            if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
+                cur = lm_ggml_scale(ctx0, cur, 0.5F);
+            }
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = lm_ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        lm_ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    // ref: https://github.com/facebookresearch/chameleon
+    // based on the original build_llama() function, changes:
+    //   * qk-norm
+    //   * swin-norm
+    //   * removed bias
+    //   * removed MoE
+    struct lm_ggml_cgraph * build_chameleon() {
+        struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        LM_GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct lm_ggml_tensor * cur;
+        struct lm_ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct lm_ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct lm_ggml_tensor * inpSA = inpL;
+
+            // norm
+            if (hparams.swin_norm) {
+                cur = inpL;
+            } else {
+                cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            // self-attention
+            {
                 // compute Q and K and RoPE them
                 struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
 
                 struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
 
                 struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = lm_ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
+                                lm_ggml_element_size(Qcur) * n_embd_head,
+                                lm_ggml_element_size(Qcur) * n_embd_head * n_head,
+                                0);
+                    cb(Qcur, "Qcur", il);
+
+                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
+                                model.layers[il].attn_q_norm,
+                                model.layers[il].attn_q_norm_b,
+                                LLM_NORM, cb, il);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = lm_ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
+                                lm_ggml_element_size(Kcur) * n_embd_head,
+                                lm_ggml_element_size(Kcur) * n_embd_head * n_head_kv,
+                                0);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
+                               model.layers[il].attn_k_norm,
+                               model.layers[il].attn_k_norm_b,
+                               LLM_NORM, cb, il);
+                    cb(Kcur, "Kcur", il);
                 }
 
                 Qcur = lm_ggml_rope_ext(
-                    ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                    ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
-                    ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                    ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
+                        model.layers[il].wo, nullptr,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
+                if (hparams.swin_norm) {
+                    cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                }
             }
 
             if (il == n_layer - 1) {
@@ -14652,10 +16289,12 @@ struct llm_build_context {
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
+            if (!hparams.swin_norm) {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+            }
 
             cur = llm_build_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_up,   NULL, NULL,
@@ -14665,6 +16304,13 @@ struct llm_build_context {
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
 
+            if (hparams.swin_norm) {
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+            }
+
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -14684,6 +16330,19 @@ struct llm_build_context {
 
         // lm_head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output_with_img_logits", -1);
+
+        // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs.
+        // Needs to be removed once image outputs are supported.
+        int img_token_end_idx = 8196;
+        int img_token_start_idx = 4;
+        int num_img_tokens = img_token_end_idx - img_token_start_idx;
+        // creates 1d tensor of size num_img_tokens and values -FLT_MAX,
+        // which ensures that text token values are always at least larger than image token values
+        struct lm_ggml_tensor * img_logits = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_F32, num_img_tokens);
+        img_logits = lm_ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX);
+        cb(img_logits, "img_logits", -1);
+        cur = lm_ggml_set_1d(ctx0, cur, img_logits, lm_ggml_element_size(cur) * img_token_start_idx);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -14728,7 +16387,7 @@ static struct lm_ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
 
 static struct lm_ggml_cgraph * llama_build_graph(
          llama_context & lctx,
-    const llama_ubatch & batch,
+    const llama_ubatch & ubatch,
                   bool   worst_case) {
     const auto & model = lctx.model;
 
@@ -14743,20 +16402,21 @@ static struct lm_ggml_cgraph * llama_build_graph(
         if (!lctx.cparams.offload_kqv) {
             if (strcmp(name, "kqv_merged_cont") == 0) {
                 // all nodes between the KV store and the attention output are run on the CPU
-                lm_ggml_backend_sched_set_tensor_backend(lctx.sched, cur, lctx.backend_cpu);
+                lm_ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, lctx.backend_cpu);
             }
         }
 
         // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
         // FIXME: fix in lm_ggml_backend_sched
         const bool full_offload = lctx.model.n_gpu_layers > (int)lctx.model.hparams.n_layer;
-        if (batch.n_tokens < 32 || full_offload) {
+        if (ubatch.n_tokens < 32 || full_offload) {
             if (il != -1 && strcmp(name, "norm") == 0) {
-                for (auto * backend : lctx.backends) {
-                    if (lm_ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft) &&
-                        (lm_ggml_backend_supports_op(backend, cur) || lm_ggml_backend_offload_op(backend, cur))) {
-                        lm_ggml_backend_sched_set_tensor_backend(lctx.sched, cur, backend);
-                        break;
+                const auto & dev_layer = lctx.model.dev_layer.at(il);
+                for (auto & backend : lctx.backends) {
+                    if (lm_ggml_backend_get_device(backend.get()) == dev_layer.dev) {
+                        if (lm_ggml_backend_supports_op(backend.get(), cur)) {
+                            lm_ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, backend.get());
+                        }
                     }
                 }
             }
@@ -14765,12 +16425,14 @@ static struct lm_ggml_cgraph * llama_build_graph(
 
     struct lm_ggml_cgraph * result = NULL;
 
-    struct llm_build_context llm(lctx, batch, cb, worst_case);
+    struct llm_build_context llm(lctx, ubatch, cb, worst_case);
 
     llm.init();
 
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
             {
                 result = llm.build_llama();
             } break;
@@ -14856,6 +16518,10 @@ static struct lm_ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_minicpm();
             } break;
+        case LLM_ARCH_MINICPM3:
+            {
+                result = llm.build_minicpm3();
+            } break;
         case LLM_ARCH_GEMMA:
             {
                 result = llm.build_gemma();
@@ -14888,6 +16554,10 @@ static struct lm_ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_OLMOE:
+            {
+                result = llm.build_olmoe();
+            } break;
         case LLM_ARCH_OPENELM:
             {
                 result = llm.build_openelm();
@@ -14936,6 +16606,14 @@ static struct lm_ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_exaone();
             } break;
+        case LLM_ARCH_RWKV6:
+            {
+                result = llm.build_rwkv6();
+            } break;
+        case LLM_ARCH_CHAMELEON:
+            {
+                result = llm.build_chameleon();
+            } break;
         default:
             LM_GGML_ABORT("fatal error");
     }
@@ -14998,7 +16676,7 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
-static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
+static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
     //
     // set input data
     //
@@ -15007,28 +16685,28 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     const auto & cparams = lctx.cparams;
     const auto & kv_self = lctx.kv_self;
 
-    if (batch.token) {
-        const int64_t n_tokens = batch.n_tokens;
+    if (ubatch.token) {
+        const int64_t n_tokens = ubatch.n_tokens;
 
-        lm_ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*lm_ggml_element_size(lctx.inp_tokens));
+        lm_ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*lm_ggml_element_size(lctx.inp_tokens));
     }
 
-    if (batch.embd) {
+    if (ubatch.embd) {
         const int64_t n_embd   = hparams.n_embd;
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
 
-        lm_ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*lm_ggml_element_size(lctx.inp_embd));
+        lm_ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*lm_ggml_element_size(lctx.inp_embd));
     }
 
-    if (batch.pos && lctx.inp_pos) {
-        const int64_t n_tokens = batch.n_tokens;
+    if (ubatch.pos && lctx.inp_pos) {
+        const int64_t n_tokens = ubatch.n_tokens;
 
-        lm_ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*lm_ggml_element_size(lctx.inp_pos));
+        lm_ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*lm_ggml_element_size(lctx.inp_pos));
     }
 
     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         LM_GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
 
         LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
         int32_t * data = (int32_t *) lctx.inp_out_ids->data;
@@ -15037,10 +16715,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             for (int i = 0; i < n_tokens; ++i) {
                 data[i] = i;
             }
-        } else if (batch.output) {
+        } else if (ubatch.output) {
             int32_t n_outputs = 0;
             for (int i = 0; i < n_tokens; ++i) {
-                if (batch.output[i]) {
+                if (ubatch.output[i]) {
                     data[n_outputs++] = i;
                 }
             }
@@ -15065,9 +16743,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn && !lctx.is_encoding) {
             const int64_t n_kv         = kv_self.n;
-            const int64_t n_tokens     = batch.n_tokens;
-            const int64_t n_seq_tokens = batch.n_seq_tokens;
-            const int64_t n_seqs       = batch.n_seqs;
+            const int64_t n_tokens     = ubatch.n_tokens;
+            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+            const int64_t n_seqs       = ubatch.n_seqs;
 
 
             float * data     = nullptr;
@@ -15084,14 +16762,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             }
 
             // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the batch.
+            // of the correct sequence for each token of the ubatch.
             // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
             for (int h = 0; h < 1; ++h) {
                 for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = batch.seq_id[s][0];
+                    const llama_seq_id seq_id = ubatch.seq_id[s][0];
 
                     for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = batch.pos[s*n_seq_tokens + j];
+                        const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
 
                         for (int i = 0; i < n_kv; ++i) {
                             float f;
@@ -15137,9 +16815,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
                 }
             }
         } else {
-            const int64_t n_tokens     = batch.n_tokens;
-            const int64_t n_seq_tokens = batch.n_seq_tokens;
-            const int64_t n_seqs       = batch.n_seqs;
+            const int64_t n_tokens     = ubatch.n_tokens;
+            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+            const int64_t n_seqs       = ubatch.n_seqs;
             // when using kv cache, the mask needs to match the kv cache size
             const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
 
@@ -15149,7 +16827,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
 
             for (int h = 0; h < 1; ++h) {
                 for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = batch.seq_id[s1][0];
+                    const llama_seq_id seq_id = ubatch.seq_id[s1][0];
 
                     for (int j = 0; j < n_seq_tokens; ++j) {
                         const int32_t tj = s1*n_seq_tokens + j;
@@ -15159,10 +16837,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
                                 const int32_t ti = s0*n_seq_tokens + i;
                                 float f = -INFINITY;
 
-                                for (int s = 0; s < batch.n_seq_id[s0]; ++s) {
-                                    if (batch.seq_id[s0][s] == seq_id) {
+                                for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
+                                    if (ubatch.seq_id[s0][s] == seq_id) {
                                         if (hparams.use_alibi) {
-                                            f = -std::abs(batch.pos[ti] - batch.pos[tj]);
+                                            f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
                                         } else {
                                             f = 0.0f;
                                         }
@@ -15184,9 +16862,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-        const int64_t n_tokens     = batch.n_tokens;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_seqs       = batch.n_seqs;
+        const int64_t n_tokens     = ubatch.n_tokens;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_seqs       = ubatch.n_seqs;
 
         LM_GGML_ASSERT(lctx.inp_mean);
         LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -15197,12 +16875,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         std::vector sum(n_tokens, 0);
 
         for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
 
-            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
+            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
             LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
 
-            sum[seq_id] += batch.n_seq_tokens;
+            sum[seq_id] += ubatch.n_seq_tokens;
         }
 
         std::vector div(n_tokens, 0.0f);
@@ -15214,7 +16892,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         }
 
         for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
 
             for (int i = 0; i < n_seq_tokens; ++i) {
                 data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
@@ -15222,10 +16900,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         }
     }
 
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
-        const int64_t n_tokens     = batch.n_tokens;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_seqs       = batch.n_seqs;
+    if (cparams.embeddings && (
+                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
+        const int64_t n_tokens     = ubatch.n_tokens;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_seqs       = ubatch.n_seqs;
 
         LM_GGML_ASSERT(lctx.inp_cls);
         LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -15234,13 +16914,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         memset(lctx.inp_cls->data, 0, n_tokens * lm_ggml_element_size(lctx.inp_cls));
 
         for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
 
-            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
-            LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
+            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
+            LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = batch.pos[s*n_seq_tokens + i];
+                const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
 
                 if (pos == 0) {
                     data[seq_id] = s*n_seq_tokens + i;
@@ -15250,9 +16930,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens     = batch.n_tokens;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_seqs       = batch.n_seqs;
+        const int64_t n_tokens     = ubatch.n_tokens;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_seqs       = ubatch.n_seqs;
 
         LM_GGML_ASSERT(lctx.inp_cls);
         LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -15264,13 +16944,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         std::vector last_row(n_tokens, -1);
 
         for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
 
-            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
+            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
             LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = batch.pos[s*n_seq_tokens + i];
+                const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
 
                 if (pos >= last_pos[seq_id]) {
                     last_pos[seq_id] = pos;
@@ -15295,7 +16975,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
 
             // clear unused states
             for (int i = 0; i < n_kv; ++i) {
-                uint32_t        cell_id = i + kv_self.head;
+                const uint32_t  cell_id = i + kv_self.head;
                 llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
 
                 data[i] = (float) (kv_cell.src >= 0);
@@ -15332,10 +17012,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     }
 
     if (lctx.inp_pos_bucket) {
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
 
         LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
-        LM_GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
+        LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
 
         int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
 
@@ -15344,7 +17024,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
                     }
                 }
             }
@@ -15352,7 +17032,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
+                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
                     }
                 }
             }
@@ -15368,10 +17048,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
 
     if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
         const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
 
         LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
-        LM_GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
+        LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
 
         float * data = (float *) lctx.inp_KQ_mask_cross->data;
 
@@ -15379,8 +17059,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             for (int j = 0; j < n_tokens; ++j) {
                 for (int i = 0; i < n_output_enc; ++i) {
                     float f = -INFINITY;
-                    for (int s = 0; s < batch.n_seq_id[j]; ++s) {
-                        const llama_seq_id seq_id = batch.seq_id[j][s];
+                    for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id[j][s];
                         if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
                             f = 0.0f;
                         }
@@ -15422,7 +17102,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
         lctx.output_ids.resize(n_batch);
     }
 
-    const size_t prev_size = lctx.buf_output ? lm_ggml_backend_buffer_get_size(lctx.buf_output) : 0;
+    const size_t prev_size = lctx.buf_output ? lm_ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0;
     const size_t new_size  = (logits_size + embd_size) * sizeof(float);
 
     // alloc only when more than the current capacity is required
@@ -15433,20 +17113,26 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
             // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
             LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
-            lm_ggml_backend_buffer_free(lctx.buf_output);
             lctx.buf_output = nullptr;
             lctx.logits = nullptr;
             lctx.embd = nullptr;
         }
 
-        lctx.buf_output = lm_ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size);
+        auto * buft = lm_ggml_backend_cpu_buffer_type();
+        // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
+        auto * output_dev = lctx.model.dev_output.dev;
+        auto * output_dev_host_buft = output_dev ? lm_ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
+        if (output_dev_host_buft) {
+            buft = output_dev_host_buft;
+        }
+        lctx.buf_output.reset(lm_ggml_backend_buft_alloc_buffer(buft, new_size));
         if (lctx.buf_output == nullptr) {
             LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
             return 0;
         }
     }
 
-    float * output_base = (float *) lm_ggml_backend_buffer_get_base(lctx.buf_output);
+    float * output_base = (float *) lm_ggml_backend_buffer_get_base(lctx.buf_output.get());
 
     lctx.logits = has_logits ? output_base               : nullptr;
     lctx.embd   = has_embd   ? output_base + logits_size : nullptr;
@@ -15458,7 +17144,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     // set all ids as invalid (negative)
     std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
 
-    lm_ggml_backend_buffer_clear(lctx.buf_output, 0);
+    lm_ggml_backend_buffer_clear(lctx.buf_output.get(), 0);
 
     lctx.n_outputs = 0;
 
@@ -15504,26 +17190,24 @@ static void llama_output_reorder(struct llama_context * ctx) {
 }
 
 static void llama_graph_compute(
-        llama_context & lctx,
-          lm_ggml_cgraph * gf,
-                  int   n_threads) {
-#ifdef LM_GGML_USE_METAL
-    if (lm_ggml_backend_is_metal(lctx.backend_metal)) {
-        lm_ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
-    }
-#endif
-
+          llama_context & lctx,
+            lm_ggml_cgraph * gf,
+                    int   n_threads,
+        lm_ggml_threadpool * threadpool) {
     if (lctx.backend_cpu != nullptr) {
-        lm_ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
+        lm_ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
         lm_ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }
-#ifdef LM_GGML_USE_BLAS
-    if (lctx.backend_blas != nullptr) {
-        lm_ggml_backend_blas_set_n_threads(lctx.backend_blas, n_threads);
+
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : lctx.set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
     }
-#endif
 
-    lm_ggml_backend_sched_graph_compute_async(lctx.sched, gf);
+    auto err = lm_ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
+    if (err != LM_GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: lm_ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
+    }
 
     // fprintf(stderr, "splits: %d\n", lm_ggml_backend_sched_get_n_splits(lctx.sched));
 }
@@ -15539,21 +17223,34 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = false;
-    const uint32_t n_tokens_all = batch_all.n_tokens;
 
-    if (n_tokens_all == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens_all = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    LM_GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
+    LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    if (batch.token) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+        }
+    }
 
     LM_GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
@@ -15580,9 +17277,9 @@ static int llama_decode_internal(
     lctx.embd_seq.clear();
 
     // count outputs
-    if (batch_all.logits && !embd_pooled) {
+    if (batch.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch_all.logits[i] != 0;
+            n_outputs += batch.logits[i] != 0;
         }
     } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
@@ -15591,7 +17288,7 @@ static int llama_decode_internal(
         n_outputs = 1;
     }
 
-    lctx.sbatch.from_batch(batch_all, n_embd,
+    lctx.sbatch.from_batch(batch, n_embd,
         /* simple_split */ !kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
@@ -15635,6 +17332,8 @@ static int llama_decode_internal(
         }
 
         int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        lm_ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+
         LM_GGML_ASSERT(n_threads > 0);
 
         // non-causal masks do not use the KV cache
@@ -15663,14 +17362,14 @@ static int llama_decode_internal(
 
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
-        lm_ggml_backend_sched_reset(lctx.sched);
-        lm_ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+        lm_ggml_backend_sched_reset(lctx.sched.get());
+        lm_ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
         lm_ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
 
         // the output is always the last tensor in the graph
-        struct lm_ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
-        struct lm_ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+        struct lm_ggml_tensor * res  = lm_ggml_graph_node(gf, -1);
+        struct lm_ggml_tensor * embd = lm_ggml_graph_node(gf, -2);
 
         if (lctx.n_outputs == 0) {
             // no output
@@ -15679,9 +17378,9 @@ static int llama_decode_internal(
         } else if (cparams.embeddings) {
             res  = nullptr; // do not extract logits for embedding case
             embd = nullptr;
-            for (int i = gf->n_nodes - 1; i >= 0; --i) {
-                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
-                    embd = gf->nodes[i];
+            for (int i = lm_ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+                if (strcmp(lm_ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
+                    embd = lm_ggml_graph_node(gf, i);
                     break;
                 }
             }
@@ -15692,11 +17391,11 @@ static int llama_decode_internal(
         }
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (lm_ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-        lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
+        lm_ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
 
         llama_set_inputs(lctx, ubatch);
 
-        llama_graph_compute(lctx, gf, n_threads);
+        llama_graph_compute(lctx, gf, n_threads, threadpool);
 
         // update the kv ring buffer
         {
@@ -15715,7 +17414,7 @@ static int llama_decode_internal(
 
         // extract logits
         if (res) {
-            lm_ggml_backend_t backend_res = lm_ggml_backend_sched_get_tensor_backend(lctx.sched, res);
+            lm_ggml_backend_t backend_res = lm_ggml_backend_sched_get_tensor_backend(lctx.sched.get(), res);
             LM_GGML_ASSERT(backend_res != nullptr);
             LM_GGML_ASSERT(lctx.logits != nullptr);
 
@@ -15731,7 +17430,7 @@ static int llama_decode_internal(
 
         // extract embeddings
         if (embd) {
-            lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+            lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(lctx.sched.get(), embd);
             LM_GGML_ASSERT(backend_embd != nullptr);
 
             switch (cparams.pooling_type) {
@@ -15764,6 +17463,20 @@ static int llama_decode_internal(
                             lm_ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                         }
                     } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // extract the rerank score - a single float per sequence
+                        auto & embd_seq_out = lctx.embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(1);
+                            lm_ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        }
+                    } break;
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
                     {
                         LM_GGML_ABORT("unknown pooling type");
@@ -15812,7 +17525,7 @@ static int llama_decode_internal(
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    lm_ggml_backend_sched_reset(lctx.sched);
+    lm_ggml_backend_sched_reset(lctx.sched.get());
 
     return 0;
 }
@@ -15828,23 +17541,35 @@ static int llama_decode_internal(
 //
 static int llama_encode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = true;
 
-    const uint32_t n_tokens = batch.n_tokens;
-
-    if (n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
     LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    if (batch.token) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+        }
+    }
+
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     LM_GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
 
@@ -15873,11 +17598,13 @@ static int llama_encode_internal(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;
 
-    const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    lm_ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+
     LM_GGML_ASSERT(n_threads > 0);
 
-    lm_ggml_backend_sched_reset(lctx.sched);
-    lm_ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+    lm_ggml_backend_sched_reset(lctx.sched.get());
+    lm_ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
     lm_ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
 
@@ -15887,29 +17614,29 @@ static int llama_encode_internal(
     // there are two cases here
     if (llama_model_has_decoder(&lctx.model)) {
         // first case is an encoder-decoder T5 model where embeddings are passed to decoder
-        embd = gf->nodes[gf->n_nodes - 1];
+        embd = lm_ggml_graph_node(gf, -1);
         LM_GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor");
     } else {
         // second case is an encoder-only T5 model
         if (cparams.embeddings) {
             // only output embeddings if required
-            embd = gf->nodes[gf->n_nodes - 1];
+            embd = lm_ggml_graph_node(gf, -1);
             if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+                embd = lm_ggml_graph_node(gf, -2);
             }
             LM_GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
         }
     }
 
-    lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
+    lm_ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
 
     llama_set_inputs(lctx, ubatch);
 
-    llama_graph_compute(lctx, gf, n_threads);
+    llama_graph_compute(lctx, gf, n_threads, threadpool);
 
     // extract embeddings
     if (embd) {
-        lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+        lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(lctx.sched.get(), embd);
         LM_GGML_ASSERT(backend_embd != nullptr);
 
         if (llama_model_has_decoder(&lctx.model)) {
@@ -15959,6 +17686,13 @@ static int llama_encode_internal(
                             lm_ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                         }
                     } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
+                        //       wait for an encoder model that requires this pooling type in order to test it
+                        //       https://github.com/ggerganov/llama.cpp/pull/9510
+                        LM_GGML_ABORT("RANK pooling not implemented yet");
+                    }
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
                     {
                         LM_GGML_ABORT("unknown pooling type");
@@ -15969,7 +17703,7 @@ static int llama_encode_internal(
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    lm_ggml_backend_sched_reset(lctx.sched);
+    lm_ggml_backend_sched_reset(lctx.sched.get());
 
     return 0;
 }
@@ -16183,11 +17917,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 #else
     // lm_ggml_graph defrag
 
-    lm_ggml_backend_sched_reset(lctx.sched);
+    lm_ggml_backend_sched_reset(lctx.sched.get());
 
     lm_ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
 
-    llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+    llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
 #endif
 
     //const int64_t t_end = lm_ggml_time_us();
@@ -16205,15 +17939,15 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         }
 
         {
-            lm_ggml_backend_sched_reset(lctx.sched);
+            lm_ggml_backend_sched_reset(lctx.sched.get());
 
             lm_ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
 
-            lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
+            lm_ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
 
             llama_set_k_shift(lctx);
 
-            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+            llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
 
             need_reserve = true;
         }
@@ -16249,8 +17983,8 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         lm_ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
 
         // initialize scheduler with the worst-case graph
-        lm_ggml_backend_sched_reset(lctx.sched);
-        if (!lm_ggml_backend_sched_reserve(lctx.sched, gf)) {
+        lm_ggml_backend_sched_reset(lctx.sched.get());
+        if (!lm_ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         }
     }
@@ -16296,10 +18030,9 @@ static void llama_tensor_dequantize_internal(
     }
     float * f32_output = (float *) output.data();
 
-    lm_ggml_type_traits_t qtype;
+    const lm_ggml_type_traits * qtype = lm_ggml_get_type_traits(tensor->type);
     if (lm_ggml_is_quantized(tensor->type)) {
-        qtype = lm_ggml_internal_get_type_traits(tensor->type);
-        if (qtype.to_float == NULL) {
+        if (qtype->to_float == NULL) {
             throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", lm_ggml_type_name(tensor->type)));
         }
     } else if (tensor->type != LM_GGML_TYPE_F16 &&
@@ -16313,7 +18046,7 @@ static void llama_tensor_dequantize_internal(
         } else if (tensor->type == LM_GGML_TYPE_BF16) {
             lm_ggml_bf16_to_fp32_row((lm_ggml_bf16_t *)tensor->data, f32_output, nelements);
         } else if (lm_ggml_is_quantized(tensor->type)) {
-            qtype.to_float(tensor->data, f32_output, nelements);
+            qtype->to_float(tensor->data, f32_output, nelements);
         } else {
             LM_GGML_ABORT("fatal error"); // unreachable
         }
@@ -16349,7 +18082,7 @@ static void llama_tensor_dequantize_internal(
             } else if (typ == LM_GGML_TYPE_BF16) {
                 lm_ggml_bf16_to_fp32_row((lm_ggml_bf16_t *)inbuf, outbuf, nels);
             } else {
-                qtype.to_float(inbuf, outbuf, nels);
+                qtype->to_float(inbuf, outbuf, nels);
             }
         };
         workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
@@ -16424,6 +18157,9 @@ static lm_ggml_type llama_tensor_get_type(quantize_state_internal & qs, lm_ggml_
                      new_type == LM_GGML_TYPE_Q4_0_8_8) {
                 new_type = LM_GGML_TYPE_Q4_0;
             }
+            else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
+                new_type = LM_GGML_TYPE_Q4_K;
+            }
         }
     } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
                ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
@@ -16623,6 +18359,8 @@ static lm_ggml_type llama_tensor_get_type(quantize_state_internal & qs, lm_ggml_
     }
     if (convert_incompatible_tensor) {
         switch (new_type) {
+            case LM_GGML_TYPE_TQ1_0:
+            case LM_GGML_TYPE_TQ2_0:  new_type = LM_GGML_TYPE_Q4_0; break;  // TODO: use a symmetric type instead
             case LM_GGML_TYPE_IQ2_XXS:
             case LM_GGML_TYPE_IQ2_XS:
             case LM_GGML_TYPE_IQ2_S:
@@ -16728,6 +18466,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q5_K_S:
         case LLAMA_FTYPE_MOSTLY_Q5_K_M:  default_type = LM_GGML_TYPE_Q5_K;    break;
         case LLAMA_FTYPE_MOSTLY_Q6_K:    default_type = LM_GGML_TYPE_Q6_K;    break;
+        case LLAMA_FTYPE_MOSTLY_TQ1_0:   default_type = LM_GGML_TYPE_TQ1_0;   break;
+        case LLAMA_FTYPE_MOSTLY_TQ2_0:   default_type = LM_GGML_TYPE_TQ2_0;   break;
         case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = LM_GGML_TYPE_IQ2_XXS; break;
         case LLAMA_FTYPE_MOSTLY_IQ2_XS:  default_type = LM_GGML_TYPE_IQ2_XS;  break;
         case LLAMA_FTYPE_MOSTLY_IQ2_S:   default_type = LM_GGML_TYPE_IQ2_XS;  break;
@@ -16795,44 +18535,62 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }
 
     const size_t align = LM_GGUF_DEFAULT_ALIGNMENT;
-    struct lm_gguf_context * ctx_out = lm_gguf_init_empty();
+    lm_gguf_context_ptr ctx_out { lm_gguf_init_empty() };
 
     // copy the KV pairs from the input file
-    lm_gguf_set_kv     (ctx_out, ml.meta);
-    lm_gguf_set_val_u32(ctx_out, "general.quantization_version", LM_GGML_QNT_VERSION); // TODO: use LLM_KV
-    lm_gguf_set_val_u32(ctx_out, "general.file_type", ftype); // TODO: use LLM_KV
+    lm_gguf_set_kv     (ctx_out.get(), ml.meta.get());
+    lm_gguf_set_val_u32(ctx_out.get(), "general.quantization_version", LM_GGML_QNT_VERSION); // TODO: use LLM_KV
+    lm_gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
 
     // Remove split metadata
-    lm_gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
-    lm_gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
-    lm_gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
+    lm_gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
+    lm_gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
+    lm_gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
 
     if (params->kv_overrides) {
         const std::vector & overrides = *(const std::vector *)params->kv_overrides;
-        for (auto & o : overrides) {
+        for (const auto & o : overrides) {
             if (o.key[0] == 0) break;
             if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
-                lm_gguf_set_val_f32(ctx_out, o.key, o.val_f64);
+                lm_gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
             } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
-                lm_gguf_set_val_i32(ctx_out, o.key, o.val_i64);
+                lm_gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64);
             } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
-                lm_gguf_set_val_bool(ctx_out, o.key, o.val_bool);
+                lm_gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
             } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
-                lm_gguf_set_val_str(ctx_out, o.key, o.val_str);
+                lm_gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
             } else {
                 LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
             }
         }
     }
 
-    for (int i = 0; i < ml.n_tensors; ++i) {
-        const struct lm_ggml_tensor * meta = ml.get_tensor_meta(i);
+    // make a list of weights
+    std::vector tensors;
+    tensors.reserve(ml.weights_map.size());
+    for (const auto & it : ml.weights_map) {
+        tensors.push_back(&it.second);
+    }
+
+    // keep_split requires that the weights are sorted by split index
+    if (params->keep_split) {
+        std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) {
+            if (a->idx == b->idx) {
+                return a->offs < b->offs;
+            }
+            return a->idx < b->idx;
+        });
+    }
+
+    for (const auto * it : tensors) {
+        const struct lm_ggml_tensor * tensor = it->tensor;
 
-        const std::string name = lm_ggml_get_name(meta);
+        const std::string name = lm_ggml_get_name(tensor);
 
         // TODO: avoid hardcoded tensor names - use the TN_* constants
         if (name.find("attn_v.weight")   != std::string::npos ||
-            name.find("attn_qkv.weight") != std::string::npos) {
+            name.find("attn_qkv.weight") != std::string::npos ||
+            name.find("attn_kv_b.weight")!= std::string::npos) {
             ++qs.n_attention_wv;
         } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
             qs.has_output = true;
@@ -16865,32 +18623,32 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     std::vector> f32_conv_buf;
 
     uint16_t n_split = 1;
+
     // Assume split index is continuous
     if (params->keep_split) {
-        for (int i = 0; i < ml.n_tensors; ++i) {
-            n_split = std::max(uint16_t(ml.get_weight(i)->idx+1), n_split);
+        for (const auto * it : tensors) {
+            n_split = std::max(uint16_t(it->idx + 1), n_split);
         }
     }
-    std::vector ctx_outs(n_split, NULL);
-    ctx_outs[0] = ctx_out;
+    std::vector ctx_outs(n_split);
+    ctx_outs[0] = std::move(ctx_out);
 
     // populate the original tensors so we get an initial meta data
-    for (int i = 0; i < ml.n_tensors; ++i) {
-        auto weight = ml.get_weight(i);
-        uint16_t i_split = params->keep_split ? weight->idx : 0;
-        struct lm_ggml_tensor * tensor = weight->tensor;
-        if (ctx_outs[i_split] == NULL) {
-            ctx_outs[i_split] = lm_gguf_init_empty();
+    for (const auto * it : tensors) {
+        uint16_t i_split = params->keep_split ? it->idx : 0;
+        struct lm_ggml_tensor * tensor = it->tensor;
+        if (!ctx_outs[i_split]) {
+            ctx_outs[i_split].reset(lm_gguf_init_empty());
         }
-        lm_gguf_add_tensor(ctx_outs[i_split], tensor);
+        lm_gguf_add_tensor(ctx_outs[i_split].get(), tensor);
     }
 
     // Set split info if needed
     if (n_split > 1) {
         for (size_t i = 0; i < ctx_outs.size(); ++i) {
-            lm_gguf_set_val_u16(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
-            lm_gguf_set_val_u16(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
-            lm_gguf_set_val_i32(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
+            lm_gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
+            lm_gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
+            lm_gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
         }
     }
 
@@ -16900,8 +18658,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // Write metadata and close file handler
         if (fout.is_open()) {
             fout.seekp(0);
-            std::vector data(lm_gguf_get_meta_size(ctx_outs[cur_split]));
-            lm_gguf_get_meta_data(ctx_outs[cur_split], data.data());
+            std::vector data(lm_gguf_get_meta_size(ctx_outs[cur_split].get()));
+            lm_gguf_get_meta_data(ctx_outs[cur_split].get(), data.data());
             fout.write((const char *) data.data(), data.size());
             fout.close();
         }
@@ -16918,19 +18676,19 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
         fout = std::ofstream(fname, std::ios::binary);
         fout.exceptions(std::ofstream::failbit); // fail fast on write errors
-        const size_t meta_size = lm_gguf_get_meta_size(ctx_outs[cur_split]);
+        const size_t meta_size = lm_gguf_get_meta_size(ctx_outs[cur_split].get());
         // placeholder for the meta data
         ::zeros(fout, meta_size);
     };
 
     const auto tn = LLM_TN(model.arch);
     new_ofstream(0);
-    for (int i = 0; i < ml.n_tensors; ++i) {
-        auto weight = ml.get_weight(i);
-        struct lm_ggml_tensor * tensor = weight->tensor;
-        if (weight->idx != cur_split && params->keep_split) {
+    for (const auto * it : tensors) {
+        const auto & weight = *it;
+        struct lm_ggml_tensor * tensor = weight.tensor;
+        if (weight.idx != cur_split && params->keep_split) {
             close_ofstream();
-            new_ofstream(weight->idx);
+            new_ofstream(weight.idx);
         }
 
         const std::string name = lm_ggml_get_name(tensor);
@@ -16973,6 +18731,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
 
+        // do not quantize RWKV's time_mix_first tensors
+        quantize &= name.find("time_mix_first.weight") == std::string::npos;
+        quantize &= name.find("time_mix_w1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_w2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
+
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
 
@@ -17096,17 +18861,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         total_size_new += new_size;
 
         // update the gguf meta data as we go
-        lm_gguf_set_tensor_type(ctx_outs[cur_split], name.c_str(), new_type);
-        lm_gguf_set_tensor_data(ctx_outs[cur_split], name.c_str(), new_data, new_size);
+        lm_gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
+        lm_gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data, new_size);
 
         // write tensor data + padding
         fout.write((const char *) new_data, new_size);
         zeros(fout, LM_GGML_PAD(new_size, align) - new_size);
     }
     close_ofstream();
-    for (auto & c:ctx_outs) {
-        lm_gguf_free(c);
-    }
 
     LLAMA_LOG_INFO("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
     LLAMA_LOG_INFO("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
@@ -17120,55 +18882,55 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
     LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
 
-    lm_ggml_context * ctx = nullptr;
+    lm_ggml_context * ctx_init;
     struct lm_gguf_init_params meta_lm_gguf_params = {
         /* .no_alloc = */ true,
-        /* .ctx      = */ &ctx,
+        /* .ctx      = */ &ctx_init,
     };
-    struct lm_gguf_context * ctx_gguf = lm_gguf_init_from_file(path_lora, meta_lm_gguf_params);
+
+    lm_gguf_context_ptr ctx_gguf { lm_gguf_init_from_file(path_lora, meta_lm_gguf_params) };
     if (!ctx_gguf) {
         throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora));
     }
 
+    lm_ggml_context_ptr ctx { ctx_init };
+
     // check metadata
     {
         auto get_kv_str = [&](const std::string & key) -> std::string {
-            int id = lm_gguf_find_key(ctx_gguf, key.c_str());
-            return id < 0 ? "" : std::string(lm_gguf_get_val_str(ctx_gguf, id));
+            int id = lm_gguf_find_key(ctx_gguf.get(), key.c_str());
+            return id < 0 ? "" : std::string(lm_gguf_get_val_str(ctx_gguf.get(), id));
         };
         auto get_kv_f32 = [&](const std::string & key) -> float {
-            int id = lm_gguf_find_key(ctx_gguf, key.c_str());
-            return id < 0 ? 0.0f : lm_gguf_get_val_f32(ctx_gguf, id);
+            int id = lm_gguf_find_key(ctx_gguf.get(), key.c_str());
+            return id < 0 ? 0.0f : lm_gguf_get_val_f32(ctx_gguf.get(), id);
         };
         LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
 
         auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE));
         if (general_type != "adapter") {
-            lm_gguf_free(ctx_gguf);
             throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type);
         }
 
         auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
         auto general_arch = llm_arch_from_string(general_arch_str);
         if (general_arch != model->arch) {
-            lm_gguf_free(ctx_gguf);
             throw std::runtime_error("model arch and LoRA arch mismatch");
         }
 
         auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE));
         if (adapter_type != "lora") {
-            lm_gguf_free(ctx_gguf);
             throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type);
         }
 
         adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
     }
 
-    int n_tensors = lm_gguf_get_n_tensors(ctx_gguf);
+    int n_tensors = lm_gguf_get_n_tensors(ctx_gguf.get());
 
     // contexts for each buffer type
     std::map ctx_map;
-    auto get_ctx_for_buft = [&](lm_ggml_backend_buffer_type_t buft) -> lm_ggml_context * {
+    auto ctx_for_buft = [&](lm_ggml_backend_buffer_type_t buft) -> lm_ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             // add a new context
@@ -17178,7 +18940,11 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
                 /*.no_alloc   =*/ true,
             };
             lm_ggml_context * buft_ctx = lm_ggml_init(params);
+            if (!buft_ctx) {
+                return nullptr;
+            }
             ctx_map[buft] = buft_ctx;
+            adapter.ctxs.emplace_back(buft_ctx);
             return buft_ctx;
         };
         return it->second;
@@ -17189,7 +18955,7 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
     auto str_endswith = [](const std::string & str, const std::string & suffix) {
         return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
     };
-    for (lm_ggml_tensor * cur = lm_ggml_get_first_tensor(ctx); cur; cur = lm_ggml_get_next_tensor(ctx, cur)) {
+    for (lm_ggml_tensor * cur = lm_ggml_get_first_tensor(ctx.get()); cur; cur = lm_ggml_get_next_tensor(ctx.get(), cur)) {
         std::string name(cur->name);
         if (str_endswith(name, ".lora_a")) {
             replace_all(name, ".lora_a", "");
@@ -17206,8 +18972,6 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
                 ab_map[name].b = cur;
             }
         } else {
-            lm_gguf_free(ctx_gguf);
-            lm_ggml_free(ctx);
             throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
         }
     }
@@ -17218,28 +18982,20 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
         llama_lora_weight & w = it.second;
 
         if (!w.a || !w.b) {
-            lm_gguf_free(ctx_gguf);
-            lm_ggml_free(ctx);
             throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
         }
 
         // device buft and device ctx
         auto * model_tensor = llama_get_model_tensor(model, name.c_str());
         if (!model_tensor) {
-            lm_gguf_free(ctx_gguf);
-            lm_ggml_free(ctx);
             throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
         }
-        struct lm_ggml_context * dev_ctx = get_ctx_for_buft(lm_ggml_backend_buffer_get_type(model_tensor->buffer));
+        struct lm_ggml_context * dev_ctx = ctx_for_buft(lm_ggml_backend_buffer_get_type(model_tensor->buffer));
         // validate tensor shape
         if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
-            lm_gguf_free(ctx_gguf);
-            lm_ggml_free(ctx);
             throw std::runtime_error("tensor '" + name + "' has incorrect shape");
         }
         if (w.a->ne[1] != w.b->ne[0]) {
-            lm_gguf_free(ctx_gguf);
-            lm_ggml_free(ctx);
             throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
         }
         // save tensor to adapter
@@ -17254,18 +19010,15 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
     {
         adapter.ctxs.reserve(ctx_map.size());
         adapter.bufs.reserve(ctx_map.size());
-        for (auto it : ctx_map) {
+        for (auto & it : ctx_map) {
             lm_ggml_backend_buffer_type_t buft = it.first;
             lm_ggml_context * ctx_dev = it.second;
-            lm_ggml_backend_buffer_t buf = lm_ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft);
+            lm_ggml_backend_buffer_ptr buf { lm_ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft) };
             if (!buf) {
-                lm_gguf_free(ctx_gguf);
-                lm_ggml_free(ctx);
                 throw std::runtime_error("failed to allocate buffer for lora adapter\n");
             }
-            LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, lm_ggml_backend_buffer_name(buf), lm_ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-            adapter.ctxs.push_back(ctx_dev);
-            adapter.bufs.push_back(buf);
+            LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, lm_ggml_backend_buffer_name(buf.get()), lm_ggml_backend_buffer_get_size(buf.get())/1024.0/1024.0);
+            adapter.bufs.emplace_back(std::move(buf));
         }
     }
 
@@ -17274,7 +19027,7 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
         llama_file lm_gguf_file(path_lora, "rb");
         std::vector read_buf;
         auto set_tensor = [&](struct lm_ggml_tensor * orig, struct lm_ggml_tensor * dev) {
-            size_t offs = lm_gguf_get_data_offset(ctx_gguf) + lm_gguf_get_tensor_offset(ctx_gguf, lm_gguf_find_tensor(ctx_gguf, orig->name));
+            size_t offs = lm_gguf_get_data_offset(ctx_gguf.get()) + lm_gguf_get_tensor_offset(ctx_gguf.get(), lm_gguf_find_tensor(ctx_gguf.get(), orig->name));
             size_t size = lm_ggml_nbytes(orig);
             read_buf.resize(size);
             lm_gguf_file.seek(offs, SEEK_SET);
@@ -17289,11 +19042,7 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
         }
     }
 
-    LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2);
-
-    // free ctx for reading gguf
-    lm_gguf_free(ctx_gguf);
-    lm_ggml_free(ctx);
+    LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }
 
 int32_t llama_lora_adapter_set(
@@ -17348,7 +19097,9 @@ struct llama_model_params llama_model_default_params() {
 
 #ifdef LM_GGML_USE_METAL
     // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
-    result.n_gpu_layers = 999;
+    if (result.n_gpu_layers > 0) {
+        result.n_gpu_layers = 999;
+    }
 #endif
 
     return result;
@@ -17356,7 +19107,6 @@ struct llama_model_params llama_model_default_params() {
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
-        /*.seed                        =*/ LLAMA_DEFAULT_SEED,
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 2048,
         /*.n_ubatch                    =*/ 512,
@@ -17382,6 +19132,7 @@ struct llama_context_params llama_context_default_params() {
         /*.embeddings                  =*/ false,
         /*.offload_kqv                 =*/ true,
         /*.flash_attn                  =*/ false,
+        /*.no_perf                     =*/ true,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
     };
@@ -17389,6 +19140,14 @@ struct llama_context_params llama_context_default_params() {
     return result;
 }
 
+struct llama_sampler_chain_params llama_sampler_chain_default_params() {
+    struct llama_sampler_chain_params result = {
+        /*.no_perf                     =*/ true,
+    };
+
+    return result;
+}
+
 struct llama_model_quantize_params llama_model_quantize_default_params() {
     struct llama_model_quantize_params result = {
         /*.nthread                     =*/ 0,
@@ -17408,21 +19167,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
 }
 
 size_t llama_max_devices(void) {
-#if defined(LM_GGML_USE_RPC)
-    return LM_GGML_RPC_MAX_SERVERS;
-#elif defined(LM_GGML_USE_METAL)
-    return 1;
-#elif defined(LM_GGML_USE_CUDA)
-    return LM_GGML_CUDA_MAX_DEVICES;
-#elif defined(LM_GGML_USE_SYCL)
-    return LM_GGML_SYCL_MAX_DEVICES;
-#elif defined(LM_GGML_USE_VULKAN)
-    return LM_GGML_VK_MAX_DEVICES;
-#elif defined(LM_GGML_USE_CANN)
-    return LM_GGML_CANN_MAX_DEVICES;
-#else
-    return 1;
-#endif
+    return 16;
 }
 
 bool llama_supports_mmap(void) {
@@ -17434,13 +19179,12 @@ bool llama_supports_mlock(void) {
 }
 
 bool llama_supports_gpu_offload(void) {
-#if defined(LM_GGML_USE_CUDA) || defined(LM_GGML_USE_METAL)   || defined(LM_GGML_USE_VULKAN) || \
-    defined(LM_GGML_USE_SYCL) || defined(LM_GGML_USE_KOMPUTE) || defined(LM_GGML_USE_RPC)
-    // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
-    return true;
-#else
-    return false;
-#endif
+    return lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr ||
+           llama_supports_rpc();
+}
+
+bool llama_supports_rpc(void) {
+    return lm_ggml_backend_reg_by_name("RPC") != nullptr;
 }
 
 void llama_backend_init(void) {
@@ -17460,6 +19204,19 @@ void llama_numa_init(enum lm_ggml_numa_strategy numa) {
     }
 }
 
+void llama_attach_threadpool(
+             struct llama_context * ctx,
+        lm_ggml_threadpool_t   threadpool,
+        lm_ggml_threadpool_t   threadpool_batch) {
+    ctx->threadpool       = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
+}
+
+void llama_detach_threadpool(struct llama_context * ctx) {
+    ctx->threadpool       = nullptr;
+    ctx->threadpool_batch = nullptr;
+}
+
 void llama_backend_free(void) {
     lm_ggml_quantize_free();
 }
@@ -17483,25 +19240,97 @@ struct llama_model * llama_load_model_from_file(
             unsigned percentage = (unsigned) (100 * progress);
             while (percentage > *cur_percentage_p) {
                 *cur_percentage_p = percentage;
-                LLAMA_LOG_INFO(".");
+                LLAMA_LOG_CONT(".");
                 if (percentage >= 100) {
-                    LLAMA_LOG_INFO("\n");
+                    LLAMA_LOG_CONT("\n");
                 }
             }
             return true;
         };
     }
+
     if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
         // split the servers set them into model->rpc_servers
         std::string servers(params.rpc_servers);
         size_t pos = 0;
-        while ((pos = servers.find(",")) != std::string::npos) {
+        while ((pos = servers.find(',')) != std::string::npos) {
             std::string server = servers.substr(0, pos);
             model->rpc_servers.push_back(server);
             servers.erase(0, pos + 1);
         }
         model->rpc_servers.push_back(servers);
     }
+
+    // add RPC devices
+    if (!model->rpc_servers.empty()) {
+        lm_ggml_backend_reg_t rpc_reg = lm_ggml_backend_reg_by_name("RPC");
+        if (!rpc_reg) {
+            LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
+            llama_free_model(model);
+            return nullptr;
+        }
+
+        typedef lm_ggml_backend_dev_t (*lm_ggml_backend_rpc_add_device_t)(const char * endpoint);
+        lm_ggml_backend_rpc_add_device_t lm_ggml_backend_rpc_add_device_fn = (lm_ggml_backend_rpc_add_device_t) lm_ggml_backend_reg_get_proc_address(rpc_reg, "lm_ggml_backend_rpc_add_device");
+        if (!lm_ggml_backend_rpc_add_device_fn) {
+            LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
+            llama_free_model(model);
+            return nullptr;
+        }
+
+        for (const std::string & server : model->rpc_servers) {
+            lm_ggml_backend_dev_t dev = lm_ggml_backend_rpc_add_device_fn(server.c_str());
+            if (dev) {
+                model->devices.push_back(dev);
+            } else {
+                LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
+                llama_free_model(model);
+                return nullptr;
+            }
+        }
+    }
+
+    // create list of devices to use with this model
+    // currently, we use all available devices
+    // TODO: rework API to give user more control over device selection
+    for (size_t i = 0; i < lm_ggml_backend_dev_count(); ++i) {
+        lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i);
+        switch (lm_ggml_backend_dev_type(dev)) {
+            case LM_GGML_BACKEND_DEVICE_TYPE_CPU:
+            case LM_GGML_BACKEND_DEVICE_TYPE_ACCEL:
+                // skip CPU backends since they are handled separately
+                break;
+
+            case LM_GGML_BACKEND_DEVICE_TYPE_GPU:
+#ifdef LM_GGML_USE_METAL
+                if (params.n_gpu_layers > 0) {
+                    model->devices.push_back(dev);
+                }
+#else
+                model->devices.push_back(dev);
+#endif
+                break;
+        }
+    }
+
+    // if using single GPU mode, remove all except the main GPU
+    if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
+        if (params.main_gpu < 0 || params.main_gpu >= (int)model->devices.size()) {
+            LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %d)\n", __func__, params.main_gpu, (int)model->devices.size());
+            llama_free_model(model);
+            return nullptr;
+        }
+        lm_ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
+        model->devices.clear();
+        model->devices.push_back(main_gpu);
+    }
+
+    for (auto * dev : model->devices) {
+        size_t free, total; // NOLINT
+        lm_ggml_backend_dev_memory(dev, &free, &total);
+        LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, lm_ggml_backend_dev_name(dev), lm_ggml_backend_dev_description(dev), free/1024/1024);
+    }
+
     int status = llama_model_load(path_model, *model, params);
     LM_GGML_ASSERT(status <= 0);
     if (status < 0) {
@@ -17510,7 +19339,7 @@ struct llama_model * llama_load_model_from_file(
         } else if (status == -2) {
             LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
         }
-        delete model;
+        llama_free_model(model);
         return nullptr;
     }
 
@@ -17550,7 +19379,7 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
-    if (params.type_v != LM_GGML_TYPE_F16 && !params.flash_attn) {
+    if (lm_ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
     }
@@ -17571,6 +19400,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.flash_attn       = params.flash_attn;
+    cparams.no_perf          = params.no_perf;
     cparams.pooling_type     = params.pooling_type;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
@@ -17629,187 +19459,83 @@ struct llama_context * llama_new_context_with_model(
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
     LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: flash_attn = %d\n",     __func__, cparams.flash_attn);
-    LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
-    LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
-
-    ctx->abort_callback      = params.abort_callback;
-    ctx->abort_callback_data = params.abort_callback_data;
-
-    ctx->sampling.rng = std::mt19937(params.seed);
-    ctx->logits_all   = params.logits_all;
-    // build worst-case graph for encoder if a model contains encoder
-    ctx->is_encoding  = llama_model_has_encoder(model);
-
-    uint32_t kv_size = cparams.n_ctx;
-    lm_ggml_type type_k = params.type_k;
-    lm_ggml_type type_v = params.type_v;
-
-    // Mamba only needs a constant number of KV cache cells per sequence
-    if (llama_model_is_recurrent(model)) {
-        // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
-        // it's probably best to keep as much precision as possible for the states
-        type_k = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_conv for Mamba's conv_states
-        type_v = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_scan for Mamba's ssm_states
-    }
-
-    LM_GGML_ASSERT(hparams.n_embd_head_k % lm_ggml_blck_size(type_k) == 0);
-    LM_GGML_ASSERT(hparams.n_embd_head_v % lm_ggml_blck_size(type_v) == 0);
-
-    if (!hparams.vocab_only) {
-        // initialize backends
-#if defined(LM_GGML_USE_METAL)
-        if (model->n_gpu_layers > 0) {
-            ctx->backend_metal = lm_ggml_backend_metal_init();
-            if (ctx->backend_metal == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Metal backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(ctx->backend_metal);
-        }
-#elif defined(LM_GGML_USE_CUDA)
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-            lm_ggml_backend_t backend = lm_ggml_backend_cuda_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
-            for (int device = 0; device < lm_ggml_backend_cuda_get_device_count(); ++device) {
-                lm_ggml_backend_t backend = lm_ggml_backend_cuda_init(device);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, device);
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#elif defined(LM_GGML_USE_VULKAN)
-        if (model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
-            lm_ggml_backend_t backend = lm_ggml_backend_vk_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            for (int device = 0; device < lm_ggml_backend_vk_get_device_count(); ++device) {
-                lm_ggml_backend_t backend = lm_ggml_backend_vk_init(device);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize Vulkan%d backend\n", __func__, device);
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#elif defined(LM_GGML_USE_SYCL)
-        // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            lm_ggml_backend_t backend = lm_ggml_backend_sycl_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, model->main_gpu);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            // LLAMA_SPLIT_LAYER requires a backend for each GPU
-            for (int i = 0; i < lm_ggml_backend_sycl_get_device_count(); ++i) {
-                lm_ggml_backend_t backend = lm_ggml_backend_sycl_init(i);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d for No.%d backend\n", __func__, i, i);
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#elif defined(LM_GGML_USE_KOMPUTE)
-        if (model->n_gpu_layers > 0) {
-            auto * backend = lm_ggml_backend_kompute_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Kompute backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        }
-#elif defined(LM_GGML_USE_CANN)
-    // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-    // TODO: lm_ggml_backend_cann is not support split tensor now, just leave code here.
-    if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-        lm_ggml_backend_t backend = lm_ggml_backend_cann_init(model->main_gpu);
-        if (backend == nullptr) {
-            LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, model->main_gpu);
-            llama_free(ctx);
-            return nullptr;
-        }
-        ctx->backends.push_back(backend);
-    } else {
-        // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
-        // TODO: currently, CANN can't use multi-gpus, just leave code here for further cann version.
-        for (int32_t device = 0; device < lm_ggml_backend_cann_get_device_count(); ++device) {
-            lm_ggml_backend_t backend = lm_ggml_backend_cann_init(device);
+    LLAMA_LOG_INFO("%s: flash_attn = %d\n",     __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
+    LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
+
+    ctx->abort_callback      = params.abort_callback;
+    ctx->abort_callback_data = params.abort_callback_data;
+
+    ctx->logits_all = params.logits_all;
+
+    // build worst-case graph for encoder if a model contains encoder
+    ctx->is_encoding = llama_model_has_encoder(model);
+
+    uint32_t kv_size = cparams.n_ctx;
+    lm_ggml_type type_k = params.type_k;
+    lm_ggml_type type_v = params.type_v;
+
+    // Mamba only needs a constant number of KV cache cells per sequence
+    if (llama_model_is_recurrent(model)) {
+        // Mamba needs at least as many KV cells as there are sequences kept at any time
+        kv_size = std::max((uint32_t) 1, params.n_seq_max);
+        // it's probably best to keep as much precision as possible for the states
+        type_k = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_conv for Mamba's conv_states
+        type_v = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_scan for Mamba's ssm_states
+    }
+
+    LM_GGML_ASSERT(hparams.n_embd_head_k % lm_ggml_blck_size(type_k) == 0);
+    LM_GGML_ASSERT(hparams.n_embd_head_v % lm_ggml_blck_size(type_v) == 0);
+
+    if (!hparams.vocab_only) {
+        // GPU backends
+        for (auto * dev : model->devices) {
+            lm_ggml_backend_t backend = lm_ggml_backend_dev_init(dev, nullptr);
             if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device);
+                LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, lm_ggml_backend_dev_name(dev));
                 llama_free(ctx);
                 return nullptr;
             }
-            ctx->backends.push_back(backend);
-        }
-    }
-#endif
-
-#ifdef LM_GGML_USE_BLAS
-        ctx->backend_blas = lm_ggml_backend_blas_init();
-        if (ctx->backend_blas == nullptr) {
-            LLAMA_LOG_WARN("%s: failed to initialize BLAS backend\n", __func__);
-        } else {
-            ctx->backends.push_back(ctx->backend_blas);
+            ctx->backends.emplace_back(backend);
         }
-#endif
 
-#if defined(LM_GGML_USE_RPC)
-        if (model->n_gpu_layers > 0) {
-            for (const auto & endpoint : model->rpc_servers) {
-                lm_ggml_backend_t backend = lm_ggml_backend_rpc_init(endpoint.c_str());
+        // add ACCEL backends (such as BLAS)
+        for (size_t i = 0; i < lm_ggml_backend_dev_count(); ++i) {
+            lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i);
+            if (lm_ggml_backend_dev_type(dev) == LM_GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+                lm_ggml_backend_t backend = lm_ggml_backend_dev_init(dev, nullptr);
                 if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str());
+                    LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, lm_ggml_backend_dev_name(dev));
                     llama_free(ctx);
                     return nullptr;
                 }
-                ctx->backends.push_back(backend);
+                ctx->backends.emplace_back(backend);
             }
         }
-#endif
+
+        // add CPU backend
         ctx->backend_cpu = lm_ggml_backend_cpu_init();
         if (ctx->backend_cpu == nullptr) {
             LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
             llama_free(ctx);
             return nullptr;
         }
-        ctx->backends.push_back(ctx->backend_cpu);
+        ctx->backends.emplace_back(ctx->backend_cpu);
+
+        // create a list of the set_n_threads functions in the backends
+        for (auto & backend : ctx->backends) {
+            lm_ggml_backend_dev_t dev = lm_ggml_backend_get_device(backend.get());
+            lm_ggml_backend_reg_t reg = dev ? lm_ggml_backend_dev_backend_reg(dev) : nullptr;
+            if (reg) {
+                auto lm_ggml_backend_set_n_threads_fn = (lm_ggml_backend_set_n_threads_t) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_set_n_threads");
+                if (lm_ggml_backend_set_n_threads_fn) {
+                    ctx->set_n_threads_fns.emplace_back(backend.get(), lm_ggml_backend_set_n_threads_fn);
+                }
+            }
+        }
 
         if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
@@ -17830,7 +19556,7 @@ struct llama_context * llama_new_context_with_model(
             }
 
             LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                      (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
                 lm_ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                 lm_ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
@@ -17845,21 +19571,27 @@ struct llama_context * llama_new_context_with_model(
             }
 
             LLAMA_LOG_INFO("%s: %10s  output buffer size = %8.2f MiB\n", __func__,
-                    lm_ggml_backend_buffer_name(ctx->buf_output),
-                    lm_ggml_backend_buffer_get_size(ctx->buf_output) / 1024.0 / 1024.0);
+                    lm_ggml_backend_buffer_name(ctx->buf_output.get()),
+                    lm_ggml_backend_buffer_get_size(ctx->buf_output.get()) / 1024.0 / 1024.0);
         }
 
         // scheduler and compute buffers
         {
             // buffer types used for the compute buffer of each backend
             std::vector backend_buft;
-            for (auto * backend : ctx->backends) {
-                if (lm_ggml_backend_is_cpu(backend)) {
-                    // use host buffers for the CPU backend compute buffer
-                    backend_buft.push_back(llama_default_buffer_type_cpu(true));
-                } else {
-                    backend_buft.push_back(lm_ggml_backend_get_default_buffer_type(backend));
+            std::vector backend_ptrs;
+            for (auto & backend : ctx->backends) {
+                auto * buft = lm_ggml_backend_get_default_buffer_type(backend.get());
+                if (lm_ggml_backend_is_cpu(backend.get()) && !model->devices.empty()) {
+                    // use the host buffer of the first device CPU for faster transfer of the intermediate state
+                    auto * dev = model->devices[0];
+                    auto * host_buft = lm_ggml_backend_dev_host_buffer_type(dev);
+                    if (host_buft) {
+                        buft = host_buft;
+                    }
                 }
+                backend_buft.push_back(buft);
+                backend_ptrs.push_back(backend.get());
             }
 
             const size_t max_nodes = llama_model_max_nodes(*model);
@@ -17867,41 +19599,70 @@ struct llama_context * llama_new_context_with_model(
             // buffer used to store the computation graph and the tensor meta data
             ctx->buf_compute_meta.resize(lm_ggml_tensor_overhead()*max_nodes + lm_ggml_graph_overhead_custom(max_nodes, false));
 
+            // TODO: move these checks to lm_ggml_backend_sched
             // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
             bool pipeline_parallel =
                 llama_get_device_count(*model) > 1 &&
                 model->n_gpu_layers > (int)model->hparams.n_layer &&
                 model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
                 params.offload_kqv;
-#ifndef LM_GGML_USE_CUDA
-            // pipeline parallelism requires support for async compute and events
-            // currently this is only implemented in the CUDA backend
-            pipeline_parallel = false;
-#endif
-            ctx->sched = lm_ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), max_nodes, pipeline_parallel);
+
+            // pipeline parallelism requires support for async compute and events in all devices
+            if (pipeline_parallel) {
+                for (auto & backend : ctx->backends) {
+                    if (lm_ggml_backend_is_cpu(backend.get())) {
+                        // ignore CPU backend
+                        continue;
+                    }
+                    auto * dev = lm_ggml_backend_get_device(backend.get());
+                    lm_ggml_backend_dev_props props;
+                    lm_ggml_backend_dev_get_props(dev, &props);
+                    if (!props.caps.async || !props.caps.events) {
+                        // device does not support async compute or events
+                        pipeline_parallel = false;
+                        break;
+                    }
+                }
+            }
+
+            ctx->sched.reset(lm_ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
 
             if (pipeline_parallel) {
-                LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, lm_ggml_backend_sched_get_n_copies(ctx->sched));
+                LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, lm_ggml_backend_sched_get_n_copies(ctx->sched.get()));
             }
 
-            // build worst-case graph
+            // initialize scheduler with the worst-case graph
             uint32_t n_seqs = 1; // TODO: worst-case number of sequences
             uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
             llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            lm_ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
 
-            // initialize scheduler with the worst-case graph
-            if (!lm_ggml_backend_sched_reserve(ctx->sched, gf)) {
+            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+            lm_ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
+
+            // reserve pp graph first so that buffers are only allocated once
+            lm_ggml_backend_sched_reserve(ctx->sched.get(), gf_pp);
+            int n_splits_pp = lm_ggml_backend_sched_get_n_splits(ctx->sched.get());
+            int n_nodes_pp = lm_ggml_graph_n_nodes(gf_pp);
+
+            // reserve with tg graph to get the number of splits and nodes
+            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+            lm_ggml_cgraph * gf_tg = llama_build_graph(*ctx, ubatch_tg, true);
+            lm_ggml_backend_sched_reserve(ctx->sched.get(), gf_tg);
+            int n_splits_tg = lm_ggml_backend_sched_get_n_splits(ctx->sched.get());
+            int n_nodes_tg = lm_ggml_graph_n_nodes(gf_tg);
+
+            // reserve again with pp graph to avoid ggml-alloc reallocations during inference
+            gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
+            if (!lm_ggml_backend_sched_reserve(ctx->sched.get(), gf_pp)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
                 llama_free(ctx);
                 return nullptr;
             }
 
-            for (size_t i = 0; i < ctx->backends.size(); i++) {
-                lm_ggml_backend_t backend = ctx->backends[i];
+            for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+                lm_ggml_backend_t backend = backend_ptrs[i];
                 lm_ggml_backend_buffer_type_t buft = backend_buft[i];
-                size_t size = lm_ggml_backend_sched_get_buffer_size(ctx->sched, backend);
+                size_t size = lm_ggml_backend_sched_get_buffer_size(ctx->sched.get(), backend);
                 if (size > 1) {
                     LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
                             lm_ggml_backend_buft_name(buft),
@@ -17909,10 +19670,16 @@ struct llama_context * llama_new_context_with_model(
                 }
             }
 
-            // note: the number of splits during measure is higher than during inference due to the kv shift
-            int n_splits = lm_ggml_backend_sched_get_n_splits(ctx->sched);
-            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, gf->n_nodes);
-            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits);
+            if (n_nodes_pp == n_nodes_tg) {
+                LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, n_nodes_pp);
+            } else {
+                LLAMA_LOG_INFO("%s: graph nodes  = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
+            }
+            if (n_splits_pp == n_splits_tg) {
+                LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
+            } else {
+                LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
+            }
         }
     }
 
@@ -17923,14 +19690,6 @@ void llama_free(struct llama_context * ctx) {
     delete ctx;
 }
 
-const struct llama_model * llama_get_model(const struct llama_context * ctx) {
-    return &ctx->model;
-}
-
-const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) {
-    return &ctx->model.vocab;
-}
-
 uint32_t llama_n_ctx(const struct llama_context * ctx) {
     return ctx->cparams.n_ctx;
 }
@@ -17951,6 +19710,34 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }
 
+int32_t llama_n_vocab(const struct llama_model * model) {
+    return model->hparams.n_vocab;
+}
+
+int32_t llama_n_ctx_train(const struct llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
+int32_t llama_n_embd(const struct llama_model * model) {
+    return model->hparams.n_embd;
+}
+
+int32_t llama_n_layer(const struct llama_model * model) {
+    return model->hparams.n_layer;
+}
+
+int32_t llama_n_head(const struct llama_model * model) {
+    return model->hparams.n_head();
+}
+
+const struct llama_model * llama_get_model(const struct llama_context * ctx) {
+    return &ctx->model;
+}
+
+enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
+    return ctx->cparams.pooling_type;
+}
+
 enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
@@ -17964,6 +19751,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5:
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
+        case LLM_ARCH_RWKV6:
             return LLAMA_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
@@ -17980,6 +19768,9 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_ARCTIC:
         case LLM_ARCH_DEEPSEEK2:
         case LLM_ARCH_CHATGLM:
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_CHAMELEON:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
@@ -17993,6 +19784,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
         case LLM_ARCH_QWEN2MOE:
+        case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI3:
         case LLM_ARCH_GEMMA:
@@ -18003,6 +19795,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_CODESHELL:
         case LLM_ARCH_NEMOTRON:
         case LLM_ARCH_EXAONE:
+        case LLM_ARCH_MINICPM3:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here
@@ -18013,26 +19806,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     return LLAMA_ROPE_TYPE_NONE;
 }
 
-enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
-    return ctx->cparams.pooling_type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
-}
-
-int32_t llama_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
-}
-
-int32_t llama_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
-}
-
 float llama_rope_freq_scale_train(const struct llama_model * model) {
     return model->hparams.rope_freq_scale_train;
 }
@@ -18132,6 +19905,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
 bool llama_model_is_recurrent(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_MAMBA:  return true;
+        case LLM_ARCH_RWKV6:  return true;
         default:              return false;
     }
 }
@@ -18165,40 +19939,47 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const
     LM_GGML_ASSERT(cvec.ctxs.empty());
     LM_GGML_ASSERT(cvec.bufs.empty());
 
-    // count layer buffer types
-    std::map buft_layer_count;
-    for (int64_t i = 0; i < model.hparams.n_layer; i++) {
-        buft_layer_count[model.buft_layer[i].buft]++;
-    }
-
-    // allocate contexts
+    // create a context for each buffer type
     std::map ctx_map;
-    for (auto & it : buft_layer_count) {
-        int n_layers = it.second;
-        struct lm_ggml_init_params params = {
-            /*.mem_size   =*/ n_layers * lm_ggml_tensor_overhead(),
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc   =*/ true,
-        };
-        lm_ggml_context * ctx = lm_ggml_init(params);
-        if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
-            return 1;
+    auto ctx_for_buft = [&](lm_ggml_backend_buffer_type_t buft) -> lm_ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            struct lm_ggml_init_params params = {
+                /*.mem_size   =*/ model.hparams.n_layer*lm_ggml_tensor_overhead(),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+            lm_ggml_context * ctx = lm_ggml_init(params);
+            if (!ctx) {
+                return nullptr;
+            }
+            ctx_map[buft] = ctx;
+            cvec.ctxs.emplace_back(ctx);
+            return ctx;
         }
-        ctx_map[it.first] = ctx;
-    }
+        return it->second;
+    };
 
     // make tensors
     cvec.tensors.reserve(model.hparams.n_layer);
     cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
     for (size_t il = 1; il < model.hparams.n_layer; il++) {
-        struct lm_ggml_context * ctx = ctx_map.at(model.buft_layer[il].buft);
+        lm_ggml_backend_buffer_type_t buft = select_buft(*model.dev_layer.at(il).buft_list,
+            [&](lm_ggml_context * ctx) {
+                lm_ggml_tensor * cur = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, model.hparams.n_embd);
+                lm_ggml_tensor * layer_dir = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, model.hparams.n_embd);
+                return lm_ggml_add(ctx, cur, layer_dir);
+            });
+        lm_ggml_context * ctx = ctx_for_buft(buft);
+        if (!ctx) {
+            LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
+            return false;
+        }
         lm_ggml_tensor * tensor = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, model.hparams.n_embd);
         cvec.tensors.push_back(tensor);
     }
 
     // allocate tensors / buffers and zero
-    cvec.ctxs.reserve(ctx_map.size());
     cvec.bufs.reserve(ctx_map.size());
     for (auto it : ctx_map) {
         lm_ggml_backend_buffer_type_t buft = it.first;
@@ -18209,8 +19990,7 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const
             return false;
         }
         lm_ggml_backend_buffer_clear(buf, 0);
-        cvec.ctxs.push_back(ctx);
-        cvec.bufs.push_back(buf);
+        cvec.bufs.emplace_back(buf);
     }
 
     return true;
@@ -18448,14 +20228,14 @@ struct llama_data_write {
         // TODO: add more model-specific info which should prevent loading the session file if not identical
     }
 
-    void write_rng(const std::mt19937 & rng) {
-        std::ostringstream rng_ss;
-        rng_ss << rng;
+    //void write_rng(const std::mt19937 & rng) {
+    //    std::ostringstream rng_ss;
+    //    rng_ss << rng;
 
-        const std::string & rng_str = rng_ss.str();
+    //    const std::string & rng_str = rng_ss.str();
 
-        write_string(rng_str);
-    }
+    //    write_string(rng_str);
+    //}
 
     void write_output_ids(struct llama_context * ctx) {
         llama_output_reorder(ctx);
@@ -18675,17 +20455,17 @@ struct llama_data_read {
         // TODO: add more info which needs to be identical but which is not verified otherwise
     }
 
-    void read_rng(std::mt19937 & rng) {
-        std::string rng_str;
-        read_string(rng_str);
+    //void read_rng(std::mt19937 & rng) {
+    //    std::string rng_str;
+    //    read_string(rng_str);
 
-        std::istringstream rng_ss(rng_str);
-        rng_ss >> rng;
+    //    std::istringstream rng_ss(rng_str);
+    //    rng_ss >> rng;
 
-        if (rng_ss.fail()) {
-            throw std::runtime_error("failed to load RNG state");
-        }
-    }
+    //    if (rng_ss.fail()) {
+    //        throw std::runtime_error("failed to load RNG state");
+    //    }
+    //}
 
     void read_output_ids(struct llama_context * ctx) {
         std::vector output_pos;
@@ -19115,8 +20895,6 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
 
     data_ctx.write_model_info(ctx);
 
-    data_ctx.write_rng(ctx->sampling.rng);
-
     // copy outputs
     data_ctx.write_output_ids(ctx);
     data_ctx.write_logits(ctx);
@@ -19154,9 +20932,6 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
 
     data_ctx.read_model_info(ctx);
 
-    // set rng
-    data_ctx.read_rng(ctx->sampling.rng);
-
     // set outputs
     data_ctx.read_output_ids(ctx);
     data_ctx.read_logits(ctx);
@@ -19376,16 +21151,16 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
     }
 }
 
-void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
+void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
     ctx->cparams.n_threads       = n_threads;
     ctx->cparams.n_threads_batch = n_threads_batch;
 }
 
-uint32_t llama_n_threads(struct llama_context * ctx) {
+int32_t llama_n_threads(struct llama_context * ctx) {
     return ctx->cparams.n_threads;
 }
 
-uint32_t llama_n_threads_batch(struct llama_context * ctx) {
+int32_t llama_n_threads_batch(struct llama_context * ctx) {
     return ctx->cparams.n_threads_batch;
 }
 
@@ -19404,9 +21179,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
 
 struct llama_batch llama_batch_get_one(
              llama_token * tokens,
-                 int32_t   n_tokens,
-               llama_pos   pos_0,
-            llama_seq_id   seq_id) {
+                 int32_t   n_tokens) {
     return {
         /*n_tokens       =*/ n_tokens,
         /*tokens         =*/ tokens,
@@ -19415,9 +21188,6 @@ struct llama_batch llama_batch_get_one(
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
         /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ pos_0,
-        /*all_pos_1      =*/ 1,
-        /*all_seq_id     =*/ seq_id,
     };
 }
 
@@ -19430,9 +21200,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
         /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ 0,
-        /*all_pos_1      =*/ 0,
-        /*all_seq_id     =*/ 0,
     };
 
     if (embd) {
@@ -19472,7 +21239,7 @@ int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
     const int ret = llama_encode_internal(*ctx, batch);
-    if (ret < 0) {
+    if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
 
@@ -19483,7 +21250,7 @@ int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
     const int ret = llama_decode_internal(*ctx, batch);
-    if (ret < 0) {
+    if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
@@ -19491,7 +21258,7 @@ int32_t llama_decode(
 }
 
 void llama_synchronize(struct llama_context * ctx) {
-    lm_ggml_backend_sched_synchronize(ctx->sched);
+    lm_ggml_backend_sched_synchronize(ctx->sched.get());
 
     // FIXME: if multiple single tokens are evaluated without a synchronization,
     // the stats will be added to the prompt evaluation stats
@@ -19499,10 +21266,14 @@ void llama_synchronize(struct llama_context * ctx) {
 
     // add the evaluation to the stats
     if (ctx->n_queued_tokens == 1) {
-        ctx->t_eval_us += lm_ggml_time_us() - ctx->t_compute_start_us;
+        if (!ctx->cparams.no_perf) {
+            ctx->t_eval_us += lm_ggml_time_us() - ctx->t_compute_start_us;
+        }
         ctx->n_eval++;
     } else if (ctx->n_queued_tokens > 1) {
-        ctx->t_p_eval_us += lm_ggml_time_us() - ctx->t_compute_start_us;
+        if (!ctx->cparams.no_perf) {
+            ctx->t_p_eval_us += lm_ggml_time_us() - ctx->t_compute_start_us;
+        }
         ctx->n_p_eval += ctx->n_queued_tokens;
     }
 
@@ -19541,7 +21312,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
                 throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
             }
         } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+            throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
         } else {
             j = ctx->output_ids[i];
         }
@@ -19559,8 +21330,9 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
         LM_GGML_ABORT("fatal error");
-#endif
+#else
         return nullptr;
+#endif
     }
 }
 
@@ -19608,8 +21380,9 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
         LM_GGML_ABORT("fatal error");
-#endif
+#else
         return nullptr;
+#endif
     }
 }
 
@@ -19656,6 +21429,10 @@ llama_token llama_token_eos(const struct llama_model * model) {
     return llama_token_eos_impl(model->vocab);
 }
 
+llama_token llama_token_eot(const struct llama_model * model) {
+    return llama_token_eot_impl(model->vocab);
+}
+
 llama_token llama_token_cls(const struct llama_model * model) {
     return llama_token_cls_impl(model->vocab);
 }
@@ -19692,8 +21469,28 @@ llama_token llama_token_suffix(const struct llama_model * model) {
     return llama_token_suffix_impl(model->vocab);
 }
 
-llama_token llama_token_eot(const struct llama_model * model) {
-    return llama_token_eot_impl(model->vocab);
+llama_token llama_token_fim_pre(const struct llama_model * model) {
+    return llama_token_fim_pre_impl(model->vocab);
+}
+
+llama_token llama_token_fim_suf(const struct llama_model * model) {
+    return llama_token_fim_suf_impl(model->vocab);
+}
+
+llama_token llama_token_fim_mid(const struct llama_model * model) {
+    return llama_token_fim_mid_impl(model->vocab);
+}
+
+llama_token llama_token_fim_pad(const struct llama_model * model) {
+    return llama_token_fim_pad_impl(model->vocab);
+}
+
+llama_token llama_token_fim_rep(const struct llama_model * model) {
+    return llama_token_fim_rep_impl(model->vocab);
+}
+
+llama_token llama_token_fim_sep(const struct llama_model * model) {
+    return llama_token_fim_sep_impl(model->vocab);
 }
 
 //
@@ -19993,6 +21790,26 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == "rwkv-world" || tmpl_contains("rwkv-world")) {
+        // this template requires the model to have "\n\n" as EOT token
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                ss << "User: " << message->content << "\n\nAssistant:";
+            } else {
+                ss << message->content << "\n\n";
+            }
+        }
+    } else if (tmpl == "granite" || tmpl_contains("<|start_of_role|>")) {
+        // IBM Granite template
+        for (const auto & message : chat) {
+            std::string role(message->role);
+            ss << "<|start_of_role|>" << role << "<|end_of_role|>"
+               << message->content << "<|end_of_text|>\n";
+        }
+        if (add_ass) {
+            ss << "<|start_of_role|>assistant<|end_of_role|>\n";
+        }
     } else {
         // template not supported
         return -1;
@@ -20043,124 +21860,26 @@ int32_t llama_chat_apply_template(
 }
 
 //
-// grammar
+// sampling
 //
 
-struct llama_grammar * llama_grammar_init(
-        const llama_grammar_element ** rules,
-        size_t    n_rules,
-        size_t    start_rule_index) {
-    return llama_grammar_init_impl(rules, n_rules, start_rule_index);
-}
-
-void llama_grammar_free(struct llama_grammar * grammar) {
-    llama_grammar_free_impl(grammar);
-}
-
-struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    return llama_grammar_copy_impl(grammar);
-}
-
-void llama_grammar_sample(
-      const struct llama_grammar * grammar,
-      const struct llama_context * ctx,
-          llama_token_data_array * candidates) {
-    llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
+// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
+struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
+    return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
 
-void llama_sample_grammar(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-      const struct llama_grammar * grammar) {
-    llama_grammar_sample(grammar, ctx, candidates);
+struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
+    return llama_sampler_init_infill_impl(model->vocab);
 }
 
-void llama_grammar_accept_token(
-            struct llama_grammar * grammar,
-            struct llama_context * ctx,
-                     llama_token   token) {
-    llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
+struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+    return llama_sampler_init_dry_impl(model->vocab, llama_n_ctx_train(model), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
 }
 
 //
-// sampling
+// model split
 //
 
-void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
-    llama_set_rng_seed_impl(&ctx->sampling, seed);
-}
-
-void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
-    llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
-}
-
-void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
-    llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
-}
-
-void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
-    llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
-}
-
-void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
-    llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
-}
-
-void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
-    llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
-}
-
-void llama_sample_repetition_penalties(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-               const llama_token * last_tokens,
-                          size_t   penalty_last_n,
-                           float   penalty_repeat,
-                           float   penalty_freq,
-                           float   penalty_present) {
-    llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
-}
-
-void llama_sample_apply_guidance(
-          struct llama_context * ctx,
-                         float * logits,
-                         float * logits_guidance,
-                         float   scale) {
-    llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
-}
-
-llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
-}
-
-llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
-}
-
-llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
-}
-
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
-    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
-}
-
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
-}
-
 int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
     static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
     if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
@@ -20185,45 +21904,6 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
     return 0;
 }
 
-struct llama_timings llama_get_timings(struct llama_context * ctx) {
-    struct llama_timings result = {
-        /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
-        /*.t_end_ms    =*/ 1.00 * lm_ggml_time_ms(),
-        /*.t_load_ms   =*/ 1e-3 * ctx->t_load_us,
-        /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
-        /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
-        /*.t_eval_ms   =*/ 1e-3 * ctx->t_eval_us,
-
-        /*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
-        /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
-        /*.n_eval   =*/ std::max(1, ctx->n_eval),
-    };
-
-    return result;
-}
-
-void llama_print_timings(struct llama_context * ctx) {
-    const llama_timings timings = llama_get_timings(ctx);
-
-    LLAMA_LOG_INFO("\n");
-    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, timings.t_load_ms);
-    LLAMA_LOG_INFO("%s:      sample time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample);
-    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
-    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval);
-    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
-}
-
-void llama_reset_timings(struct llama_context * ctx) {
-    ctx->t_start_us  = lm_ggml_time_us();
-    ctx->t_eval_us   = ctx->n_eval   = 0;
-    ctx->t_p_eval_us = ctx->n_p_eval = 0;
-
-    ctx->sampling.reset_timings();
-}
-
 const char * llama_print_system_info(void) {
     static std::string s;
 
@@ -20235,12 +21915,14 @@ const char * llama_print_system_info(void) {
     s += "AVX512_VBMI = " + std::to_string(lm_ggml_cpu_has_avx512_vbmi()) + " | ";
     s += "AVX512_VNNI = " + std::to_string(lm_ggml_cpu_has_avx512_vnni()) + " | ";
     s += "AVX512_BF16 = " + std::to_string(lm_ggml_cpu_has_avx512_bf16()) + " | ";
+    s += "AMX_INT8 = "    + std::to_string(lm_ggml_cpu_has_amx_int8())    + " | ";
     s += "FMA = "         + std::to_string(lm_ggml_cpu_has_fma())         + " | ";
     s += "NEON = "        + std::to_string(lm_ggml_cpu_has_neon())        + " | ";
     s += "SVE = "         + std::to_string(lm_ggml_cpu_has_sve())         + " | ";
     s += "ARM_FMA = "     + std::to_string(lm_ggml_cpu_has_arm_fma())     + " | ";
     s += "F16C = "        + std::to_string(lm_ggml_cpu_has_f16c())        + " | ";
     s += "FP16_VA = "     + std::to_string(lm_ggml_cpu_has_fp16_va())     + " | ";
+    s += "RISCV_VECT = "  + std::to_string(lm_ggml_cpu_has_riscv_v())     + " | ";
     s += "WASM_SIMD = "   + std::to_string(lm_ggml_cpu_has_wasm_simd())   + " | ";
     s += "BLAS = "        + std::to_string(lm_ggml_cpu_has_blas())        + " | ";
     s += "SSE3 = "        + std::to_string(lm_ggml_cpu_has_sse3())        + " | ";
@@ -20252,7 +21934,43 @@ const char * llama_print_system_info(void) {
     return s.c_str();
 }
 
-void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
+struct llama_perf_context_data llama_perf_context(const struct llama_context * ctx) {
+    struct llama_perf_context_data data = {};
+
+    if (ctx == nullptr) {
+        return data;
+    }
+
+    data.t_start_ms  = 1e-3 * ctx->t_start_us;
+    data.t_load_ms   = 1e-3 * ctx->t_load_us;
+    data.t_p_eval_ms = 1e-3 * ctx->t_p_eval_us;
+    data.t_eval_ms   = 1e-3 * ctx->t_eval_us;
+    data.n_p_eval    = std::max(1, ctx->n_p_eval);
+    data.n_eval      = std::max(1, ctx->n_eval);
+
+    return data;
+}
+
+void llama_perf_context_print(const struct llama_context * ctx) {
+    const auto data = llama_perf_context(ctx);
+
+    const double t_end_ms = 1e-3 * lm_ggml_time_us();
+
+    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
+    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
+    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
+    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+}
+
+void llama_perf_context_reset(struct llama_context * ctx) {
+    ctx->t_start_us  = lm_ggml_time_us();
+    ctx->t_eval_us   = ctx->n_eval = 0;
+    ctx->t_p_eval_us = ctx->n_p_eval = 0;
+}
+
+void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) {
     fprintf(stream, "\n");
     fprintf(stream, "###########\n");
     fprintf(stream, "# Timings #\n");
@@ -20263,21 +21981,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
             1.0e-3 * ctx->t_eval_us / ctx->n_eval);
     fprintf(stream, "mst_p_eval: %.2f  # ms / token during prompt processing\n",
             1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
-    fprintf(stream, "mst_sample: %.2f  # ms / token during sampling\n",
-            1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
     fprintf(stream, "n_eval: %d  # number of tokens generated (excluding the first one)\n", ctx->n_eval);
     fprintf(stream, "n_p_eval: %d  # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
-    fprintf(stream, "n_sample: %d  # number of sampled tokens\n", ctx->sampling.n_sample);
     fprintf(stream, "t_eval_us: %" PRId64 "  # total microseconds spent generating tokens\n", ctx->t_eval_us);
     fprintf(stream, "t_load_us: %" PRId64 "  # total microseconds spent loading the model\n", ctx->t_load_us);
     fprintf(stream, "t_p_eval_us: %" PRId64 "  # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
-    fprintf(stream, "t_sample_us: %" PRId64 "  # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
     fprintf(stream, "ts_eval: %.2f  # tokens / second during generation\n",
             1.0e6 * ctx->n_eval / ctx->t_eval_us);
     fprintf(stream, "ts_p_eval: %.2f  # tokens / second during prompt processing\n",
             1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
-    fprintf(stream, "ts_sample: %.2f  # tokens / second during sampling\n",
-            1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
 }
 
 // For internal test use
@@ -20288,15 +22000,9 @@ const std::vector> & llama_inter
 }
 
 void llama_log_set(lm_ggml_log_callback log_callback, void * user_data) {
-    g_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
-    g_state.log_callback_user_data = user_data;
-#ifdef LM_GGML_USE_METAL
-    lm_ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#elif defined(LM_GGML_USE_CUDA)
-    lm_ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#elif defined(LM_GGML_USE_CANN)
-    lm_ggml_backend_cann_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#endif
+    lm_ggml_log_set(log_callback, user_data);
+    g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
+    g_logger_state.log_callback_user_data = user_data;
 }
 
 static void llama_log_internal_v(lm_ggml_log_level level, const char * format, va_list args) {
@@ -20305,12 +22011,12 @@ static void llama_log_internal_v(lm_ggml_log_level level, const char * format, v
     char buffer[128];
     int len = vsnprintf(buffer, 128, format, args);
     if (len < 128) {
-        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
+        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
     } else {
-        char* buffer2 = new char[len+1];
-        vsnprintf(buffer2, len+1, format, args_copy);
+        char * buffer2 = new char[len + 1];
+        vsnprintf(buffer2, len + 1, format, args_copy);
         buffer2[len] = 0;
-        g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
+        g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
         delete[] buffer2;
     }
     va_end(args_copy);
diff --git a/cpp/llama.h b/cpp/llama.h
index f27bd4b..3a6ae91 100644
--- a/cpp/llama.h
+++ b/cpp/llama.h
@@ -33,12 +33,15 @@
 
 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
 
+// TODO: use everywhere in the implementation
+#define LLAMA_TOKEN_NULL -1
+
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 8
+#define LLAMA_SESSION_VERSION 9
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 2
@@ -53,8 +56,10 @@ extern "C" {
     // TODO: show sample usage
     //
 
+    // struct llama_vocab; // TODO: add in the future
     struct llama_model;
     struct llama_context;
+    struct llama_sampler;
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -66,6 +71,7 @@ extern "C" {
         LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
         LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
         LLAMA_VOCAB_TYPE_UGM  = 4, // T5 tokenizer based on Unigram
+        LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
     };
 
     // pre-tokenization types
@@ -96,6 +102,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_BLOOM          = 23,
         LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH   = 24,
         LLAMA_VOCAB_PRE_TYPE_EXAONE         = 25,
+        LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
     };
 
     enum llama_rope_type {
@@ -166,6 +173,8 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q4_0_4_4      = 33, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0_4_8      = 34, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0_8_8      = 35, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_TQ1_0         = 36, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_TQ2_0         = 37, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -184,6 +193,7 @@ extern "C" {
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
         LLAMA_POOLING_TYPE_LAST = 3,
+        LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
     };
 
     enum llama_attention_type {
@@ -193,11 +203,12 @@ extern "C" {
     };
 
     enum llama_split_mode {
-        LLAMA_SPLIT_MODE_NONE    = 0, // single GPU
-        LLAMA_SPLIT_MODE_LAYER   = 1, // split layers and KV across GPUs
-        LLAMA_SPLIT_MODE_ROW     = 2, // split rows across GPUs
+        LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
+        LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
+        LLAMA_SPLIT_MODE_ROW   = 2, // split layers and KV across GPUs, use tensor parallelism if supported
     };
 
+    // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
     typedef struct llama_token_data {
         llama_token id; // token id
         float logit;    // log-odds of the token
@@ -205,8 +216,11 @@ extern "C" {
     } llama_token_data;
 
     typedef struct llama_token_data_array {
+        // TODO: consider SoA
+        // NOTE: this pointer can be modified by the samplers
         llama_token_data * data;
         size_t size;
+        int64_t selected; // this is the index in the data array (i.e. not the token id)
         bool sorted;
     } llama_token_data_array;
 
@@ -219,8 +233,11 @@ extern "C" {
     // - token  : the token ids of the input (used when embd is NULL)
     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     // - pos    : the positions of the respective token in the sequence
+    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
     // - seq_id : the sequence to which the respective token belongs
+    //            (if set to NULL, the sequence ID will be assumed to be 0)
     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
+    //            (if set to NULL, only the logits for last token will be returned)
     //
     typedef struct llama_batch {
         int32_t n_tokens;
@@ -231,15 +248,6 @@ extern "C" {
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
         int8_t       *  logits; // TODO: rename this to "output"
-
-        // NOTE: helpers for smooth API transition - can be deprecated in the future
-        //       for future-proof code, use the above fields instead and ignore everything below
-        //
-        // pos[i] = all_pos_0 + i*all_pos_1
-        //
-        llama_pos    all_pos_0;  // used if pos == NULL
-        llama_pos    all_pos_1;  // used if pos == NULL
-        llama_seq_id all_seq_id; // used if seq_id == NULL
     } llama_batch;
 
     enum llama_model_kv_override_type {
@@ -266,10 +274,7 @@ extern "C" {
         int32_t n_gpu_layers; // number of layers to store in VRAM
         enum llama_split_mode split_mode; // how to split the model across multiple GPUs
 
-        // main_gpu interpretation depends on split_mode:
-        // LLAMA_SPLIT_NONE: the GPU that is used for the entire model
-        // LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results
-        // LLAMA_SPLIT_LAYER: ignored
+        // the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE
         int32_t main_gpu;
 
         // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
@@ -299,13 +304,12 @@ extern "C" {
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
     //       https://github.com/ggerganov/llama.cpp/pull/7544
     struct llama_context_params {
-        uint32_t seed;              // RNG seed, -1 for random
         uint32_t n_ctx;             // text context, 0 = from model
         uint32_t n_batch;           // logical maximum batch size that can be submitted to llama_decode
         uint32_t n_ubatch;          // physical maximum batch size
         uint32_t n_seq_max;         // max number of sequences (i.e. distinct states for recurrent models)
-        uint32_t n_threads;         // number of threads to use for generation
-        uint32_t n_threads_batch;   // number of threads to use for batch processing
+        int32_t  n_threads;         // number of threads to use for generation
+        int32_t  n_threads_batch;   // number of threads to use for batch processing
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
@@ -327,11 +331,13 @@ extern "C" {
         enum lm_ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum lm_ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
-        // Keep the booleans together to avoid misalignment during copy-by-value.
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        // TODO: move at the end of the struct
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
+        bool no_perf;     // whether to measure performance timings
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -355,56 +361,14 @@ extern "C" {
         void * kv_overrides;                 // pointer to vector containing overrides
     } llama_model_quantize_params;
 
-    // grammar types
-    struct llama_grammar;
-
-    // grammar element type
-    enum llama_gretype {
-        // end of rule definition
-        LLAMA_GRETYPE_END            = 0,
-
-        // start of alternate definition for rule
-        LLAMA_GRETYPE_ALT            = 1,
-
-        // non-terminal element: reference to rule
-        LLAMA_GRETYPE_RULE_REF       = 2,
-
-        // terminal element: character (code point)
-        LLAMA_GRETYPE_CHAR           = 3,
-
-        // inverse char(s) ([^a], [^a-b] [^abc])
-        LLAMA_GRETYPE_CHAR_NOT       = 4,
+    typedef struct llama_logit_bias {
+        llama_token token;
+        float bias;
+    } llama_logit_bias;
 
-        // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
-        // be an inclusive range ([a-z])
-        LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or
-        // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
-        LLAMA_GRETYPE_CHAR_ALT       = 6,
-
-        // any character (.)
-        LLAMA_GRETYPE_CHAR_ANY       = 7,
-    };
-
-    typedef struct llama_grammar_element {
-        enum llama_gretype type;
-        uint32_t           value; // Unicode code point or rule ID
-    } llama_grammar_element;
-
-    // performance timing information
-    struct llama_timings {
-        double t_start_ms;
-        double t_end_ms;
-        double t_load_ms;
-        double t_sample_ms;
-        double t_p_eval_ms;
-        double t_eval_ms;
-
-        int32_t n_sample;
-        int32_t n_p_eval;
-        int32_t n_eval;
-    };
+    typedef struct llama_sampler_chain_params {
+        bool no_perf; // whether to measure performance timings
+    } llama_sampler_chain_params;
 
     // used in chat template
     typedef struct llama_chat_message {
@@ -416,8 +380,10 @@ extern "C" {
     struct llama_lora_adapter;
 
     // Helpers for getting default parameters
-    LLAMA_API struct llama_model_params llama_model_default_params(void);
-    LLAMA_API struct llama_context_params llama_context_default_params(void);
+    // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
+    LLAMA_API struct llama_model_params          llama_model_default_params(void);
+    LLAMA_API struct llama_context_params        llama_context_default_params(void);
+    LLAMA_API struct llama_sampler_chain_params  llama_sampler_chain_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
 
     // Initialize the llama + ggml backend
@@ -428,15 +394,23 @@ extern "C" {
     //optional:
     LLAMA_API void llama_numa_init(enum lm_ggml_numa_strategy numa);
 
+    // Optional: an auto threadpool gets created in ggml if not passed explicitly
+    LLAMA_API void llama_attach_threadpool(
+               struct   llama_context * ctx,
+            lm_ggml_threadpool_t   threadpool,
+            lm_ggml_threadpool_t   threadpool_batch);
+    LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
+
     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);
 
     LLAMA_API struct llama_model * llama_load_model_from_file(
                              const char * path_model,
-            struct llama_model_params     params);
+              struct llama_model_params   params);
 
     LLAMA_API void llama_free_model(struct llama_model * model);
 
+    // TODO: rename to llama_init_from_model
     LLAMA_API struct llama_context * llama_new_context_with_model(
                      struct llama_model * model,
             struct llama_context_params   params);
@@ -451,23 +425,24 @@ extern "C" {
     LLAMA_API bool llama_supports_mmap       (void);
     LLAMA_API bool llama_supports_mlock      (void);
     LLAMA_API bool llama_supports_gpu_offload(void);
-
-    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+    LLAMA_API bool llama_supports_rpc        (void);
 
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
 
-    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
-
-    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
-    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
-
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_n_layer    (const struct llama_model * model);
+    LLAMA_API int32_t llama_n_head     (const struct llama_model * model);
+
+    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+
+    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
+    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
 
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
@@ -696,7 +671,7 @@ extern "C" {
     //
 
     // Returns the *actual* size in bytes of the state
-    // (rng, logits, embedding and kv_cache)
+    // (logits, embedding and kv_cache)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -793,15 +768,15 @@ extern "C" {
     // Decoding
     //
 
-    // Return batch for single sequence of tokens starting at pos_0
+    // Return batch for single sequence of tokens
+    // The sequence ID will be fixed to 0
+    // The position of the tokens will be tracked automatically by llama_decode
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
     LLAMA_API struct llama_batch llama_batch_get_one(
                   llama_token * tokens,
-                      int32_t   n_tokens,
-                    llama_pos   pos_0,
-                 llama_seq_id   seq_id);
+                      int32_t   n_tokens);
 
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
@@ -837,13 +812,13 @@ extern "C" {
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
-    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
+    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch);
 
     // Get the number of threads used for generation of a single token.
-    LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
+    LLAMA_API int32_t llama_n_threads(struct llama_context * ctx);
 
     // Get the number of threads used for prompt and batch processing (multiple token).
-    LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
+    LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
 
     // Set whether the model is in embeddings mode or not
     // If true, embeddings will be returned but logits will not
@@ -891,7 +866,8 @@ extern "C" {
 
     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // shape: [n_embd] (1-dimensional)
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
     //
@@ -913,6 +889,7 @@ extern "C" {
     // Special tokens
     LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
+    LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
     LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
     LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
@@ -921,15 +898,23 @@ extern "C" {
     LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
     LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
 
-    // Codellama infill tokens
-    LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
-    LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
-    LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
-    LLAMA_API llama_token llama_token_eot   (const struct llama_model * model); // End of infill middle
+    // infill tokens
+    DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
+
+    LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
 
     //
     // Tokenization
     //
+    // The API is thread-safe.
+    //
 
     /// @details Convert the provided text into tokens.
     /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
@@ -999,121 +984,117 @@ extern "C" {
                                int32_t   length);
 
     //
-    // Grammar
+    // Sampling API
+    //
+    // Sample usage:
+    //
+    //    // prepare the sampling chain at the start
+    //    auto sparams = llama_sampler_chain_default_params();
+    //
+    //    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    //
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
+    //
+    //    // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat"
+    //    // this sampler will be responsible to select the actual token
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed));
+    //
+    //    ...
+    //
+    //    // decoding loop:
+    //    while (...) {
+    //        ...
+    //
+    //        llama_decode(ctx, batch);
+    //
+    //        // sample from the logits of the last token in the batch
+    //        const llama_token id = llama_sampler_sample(smpl, ctx, -1);
+    //
+    //        // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
+    //        llama_sampler_accept(smpl, id);
+    //        ...
+    //    }
+    //
+    //    llama_sampler_free(smpl);
+    //
+    // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
+    // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
     //
 
-    /// Initialize a llama_grammar.
-    ///
-    /// @param rules The rule elements of the grammar to initialize.
-    /// @param n_rules The number of rules.
-    /// @param start_rule_index The index of the root rule (the starting point of the grammar).
-    /// @return The initialized llama_grammar or nullptr if initialization failed.
-    LLAMA_API struct llama_grammar * llama_grammar_init(
-            const llama_grammar_element ** rules,
-                                 size_t    n_rules,
-                                 size_t    start_rule_index);
-
-    LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
-
-    LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
-
-    /// @details Apply constraints from grammar
-    LLAMA_API void llama_grammar_sample(
-            const struct llama_grammar * grammar,
-            const struct llama_context * ctx,
-                llama_token_data_array * candidates);
-    LLAMA_API DEPRECATED(void llama_sample_grammar(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-      const struct llama_grammar * grammar),
-        "use llama_grammar_sample instead");
+    typedef void * llama_sampler_context_t;
 
-    /// @details Accepts the sampled token into the grammar
-    LLAMA_API void llama_grammar_accept_token(
-            struct llama_grammar * grammar,
-            struct llama_context * ctx,
-                     llama_token   token);
+    // user code can implement the interface below in order to create custom llama_sampler
+    struct llama_sampler_i {
+        const char *           (*name)  (const struct llama_sampler * smpl);                                 // can be NULL
+        void                   (*accept)(      struct llama_sampler * smpl, llama_token token);              // can be NULL
+        void                   (*apply) (      struct llama_sampler * smpl, llama_token_data_array * cur_p); // required
+        void                   (*reset) (      struct llama_sampler * smpl);                                 // can be NULL
+        struct llama_sampler * (*clone) (const struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
+        void                   (*free)  (      struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
 
-    //
-    // Sampling functions
-    //
+        // TODO: API for internal libllama usage for appending the sampling to an existing lm_ggml_cgraph
+        //void (*apply_ggml) (struct llama_sampler * smpl, ...);
+    };
 
-    // Sets the current rng seed.
-    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
+    struct llama_sampler {
+        struct llama_sampler_i  * iface;
+        llama_sampler_context_t   ctx;
+    };
 
-    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
-    /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
-    LLAMA_API void llama_sample_repetition_penalties(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-               const llama_token * last_tokens,
-                          size_t   penalty_last_n,
-                           float   penalty_repeat,
-                           float   penalty_freq,
-                           float   penalty_present);
-
-    /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
-    /// @param logits Logits extracted from the original generation context.
-    /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
-    /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-    LLAMA_API void llama_sample_apply_guidance(
-              struct llama_context * ctx,
-                             float * logits,
-                             float * logits_guidance,
-                             float   scale);
+    // mirror of llama_sampler_i:
+    LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
+    LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
+    LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
+    LLAMA_API void                   llama_sampler_reset (      struct llama_sampler * smpl);
+    LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
+    // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
+    LLAMA_API void                   llama_sampler_free  (      struct llama_sampler * smpl);
+
+    // llama_sampler_chain
+    // a type of llama_sampler that can chain multiple samplers one after another
+
+    LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);
+
+    // important: takes ownership of the sampler object and will free it when llama_sampler_free is called
+    LLAMA_API void                   llama_sampler_chain_add(      struct llama_sampler * chain, struct llama_sampler * smpl);
+    LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
+    LLAMA_API int                    llama_sampler_chain_n  (const struct llama_sampler * chain);
+
+    // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
+    LLAMA_API struct llama_sampler * llama_sampler_chain_remove(   struct llama_sampler * chain, int32_t i);
+
+    // available samplers:
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist  (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    LLAMA_API void llama_sample_softmax(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
+    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void),
+        "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_k(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                         int32_t   k,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_k      (int32_t k);
 
     /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_p(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_p      (float   p, size_t min_keep);
 
     /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
-    LLAMA_API void llama_sample_min_p(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
-
-    /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
-    LLAMA_API void llama_sample_tail_free(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   z,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_min_p      (float   p, size_t min_keep);
 
     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-    LLAMA_API void llama_sample_typical(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_typical    (float   p, size_t min_keep);
 
-    /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
-    LLAMA_API void llama_sample_entropy(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates_p,
-                           float   min_temp,
-                           float   max_temp,
-                           float   exponent_val);
+    /// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
+    LLAMA_API struct llama_sampler * llama_sampler_init_temp       (float   t);
 
-    LLAMA_API void llama_sample_temp(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   temp);
+    /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
+    LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext   (float   t, float   delta, float exponent);
+
+    /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
+    LLAMA_API struct llama_sampler * llama_sampler_init_xtc        (float   p, float   t,     size_t min_keep, uint32_t seed);
 
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
@@ -1121,36 +1102,94 @@ extern "C" {
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   tau,
-                           float   eta,
-                         int32_t   m,
-                           float * mu);
+    LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
+                             int32_t   n_vocab,
+                            uint32_t   seed,
+                               float   tau,
+                               float   eta,
+                             int32_t   m);
 
     /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat_v2(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   tau,
-                           float   eta,
-                           float * mu);
-
-    /// @details Selects the token with the highest probability.
-    ///          Does not compute the token probabilities. Use llama_sample_softmax() instead.
-    LLAMA_API llama_token llama_sample_token_greedy(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
+                            uint32_t   seed,
+                               float   tau,
+                               float   eta);
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
+            const struct llama_model * model,
+                          const char * grammar_str,
+                          const char * grammar_root);
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
+                             int32_t   n_vocab,         // llama_n_vocab()
+                         llama_token   special_eos_id,  // llama_token_eos()
+                         llama_token   linefeed_id,     // llama_token_nl()
+                             int32_t   penalty_last_n,  // last n tokens to penalize (0 = disable penalty, -1 = context size)
+                               float   penalty_repeat,  // 1.0 = disabled
+                               float   penalty_freq,    // 0.0 = disabled
+                               float   penalty_present, // 0.0 = disabled
+                                bool   penalize_nl,     // consider newlines as a repeatable token
+                                bool   ignore_eos);     // ignore the end-of-sequence token
+
+    ///  @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
+    LLAMA_API struct llama_sampler *    llama_sampler_init_dry(
+            const struct llama_model *  model,
+                               float    dry_multiplier,
+                               float    dry_base,
+                             int32_t    dry_allowed_length,
+                             int32_t    dry_penalty_last_n,
+                          const char ** seq_breakers,
+                              size_t    num_breakers);
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
+                             int32_t   n_vocab,
+                             int32_t   n_logit_bias,
+              const llama_logit_bias * logit_bias);
+
+    // this sampler is meant to be used for fill-in-the-middle infilling
+    // it's supposed to be used after top_k + top_p sampling
+    //
+    // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
+    // 2. combine probs of tokens that have the same prefix
+    //
+    // example:
+    //
+    // - before:
+    //   "hel":   0.5
+    //   "hell":  0.2
+    //   "hello": 0.1
+    //   "dummy": 0.1
+    //
+    // - after:
+    //   "hel":   0.8
+    //   "dummy": 0.1
+    //
+    // 3. discard non-EOG tokens with low prob
+    // 4. if no tokens are left -> pick EOT
+    //
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
 
-    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
-    LLAMA_API llama_token llama_sample_token(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
+    LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
+
+    /// @details Sample and accept a token from the idx-th output of the last evaluation
+    //
+    // Shorthand for:
+    //    const auto * logits = llama_get_logits_ith(ctx, idx);
+    //    llama_token_data_array cur_p = { ... init from logits ... };
+    //    llama_sampler_apply(smpl, &cur_p);
+    //    auto token = cur_p.data[cur_p.selected].id;
+    //    llama_sampler_accept(smpl, token);
+    //    return token;
+    // Returns the sampled token
+    LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
+
+    // TODO: extend in the future
+    //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);
 
     //
     // Model split
@@ -1166,12 +1205,6 @@ extern "C" {
     //  Returns the split_prefix length.
     LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
 
-    // Performance information
-    LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
-
-    LLAMA_API void llama_print_timings(struct llama_context * ctx);
-    LLAMA_API void llama_reset_timings(struct llama_context * ctx);
-
     // Print system information
     LLAMA_API const char * llama_print_system_info(void);
 
@@ -1179,65 +1212,41 @@ extern "C" {
     // If this is not called, or NULL is supplied, everything is output on stderr.
     LLAMA_API void llama_log_set(lm_ggml_log_callback log_callback, void * user_data);
 
-    LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
-
-#ifdef __cplusplus
-}
-#endif
-
-// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
-#ifdef LLAMA_API_INTERNAL
-
-#include 
-#include 
-#include 
-
-struct lm_ggml_tensor;
-
-const std::vector> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
-
-struct llama_partial_utf8 {
-    uint32_t value;    // bit value so far (unshifted)
-    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
-};
-
-struct llama_grammar_candidate {
-    size_t               index;
-    const uint32_t     * code_points;
-    llama_partial_utf8   partial_utf8;
-};
+    //
+    // Performance utils
+    //
+    // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements.
+    //
 
-using llama_grammar_rule  = std::vector<      llama_grammar_element>;
-using llama_grammar_stack = std::vector;
+    struct llama_perf_context_data {
+        double t_start_ms;
+        double t_load_ms;
+        double t_p_eval_ms;
+        double t_eval_ms;
 
-using llama_grammar_rules      = std::vector;
-using llama_grammar_stacks     = std::vector;
-using llama_grammar_candidates = std::vector;
+        int32_t n_p_eval;
+        int32_t n_eval;
+    };
 
-const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
-      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
+    struct llama_perf_sampler_data {
+        double t_sample_ms;
 
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t chr,
-              llama_grammar_stacks & new_stacks);
+        int32_t n_sample;
+    };
 
-std::vector llama_grammar_reject_candidates_for_stack(
-        const llama_grammar_rules & rules,
-        const llama_grammar_stack & stack,
-        const llama_grammar_candidates & candidates);
+    LLAMA_API struct llama_perf_context_data llama_perf_context      (const struct llama_context * ctx);
+    LLAMA_API void                           llama_perf_context_print(const struct llama_context * ctx);
+    LLAMA_API void                           llama_perf_context_reset(      struct llama_context * ctx);
 
-std::pair, llama_partial_utf8> decode_utf8(
-        const std::string & src,
-        llama_partial_utf8 partial_start);
+    // NOTE: the following work only with samplers constructed via llama_sampler_chain_init
+    LLAMA_API struct llama_perf_sampler_data llama_perf_sampler      (const struct llama_sampler * chain);
+    LLAMA_API void                           llama_perf_sampler_print(const struct llama_sampler * chain);
+    LLAMA_API void                           llama_perf_sampler_reset(      struct llama_sampler * chain);
 
-// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
-// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
+    LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx);
 
-#endif // LLAMA_API_INTERNAL
+#ifdef __cplusplus
+}
+#endif
 
 #endif // LLAMA_H
diff --git a/cpp/log.cpp b/cpp/log.cpp
new file mode 100644
index 0000000..6870605
--- /dev/null
+++ b/cpp/log.cpp
@@ -0,0 +1,434 @@
+#if defined(__ANDROID__) && defined(RNLLAMA_ANDROID_ENABLE_LOGGING)
+#include 
+#endif
+#include "log.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;
+
+void common_log_set_verbosity_thold(int verbosity) {
+    common_log_verbosity_thold = verbosity;
+}
+
+#define LOG_COL_DEFAULT "\033[0m"
+#define LOG_COL_BOLD    "\033[1m"
+#define LOG_COL_RED     "\033[31m"
+#define LOG_COL_GREEN   "\033[32m"
+#define LOG_COL_YELLOW  "\033[33m"
+#define LOG_COL_BLUE    "\033[34m"
+#define LOG_COL_MAGENTA "\033[35m"
+#define LOG_COL_CYAN    "\033[36m"
+#define LOG_COL_WHITE   "\033[37m"
+
+static int64_t t_us() {
+    return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count();
+}
+
+// colors
+enum common_log_col : int {
+    COMMON_LOG_COL_DEFAULT = 0,
+    COMMON_LOG_COL_BOLD,
+    COMMON_LOG_COL_RED,
+    COMMON_LOG_COL_GREEN,
+    COMMON_LOG_COL_YELLOW,
+    COMMON_LOG_COL_BLUE,
+    COMMON_LOG_COL_MAGENTA,
+    COMMON_LOG_COL_CYAN,
+    COMMON_LOG_COL_WHITE,
+};
+
+// disable colors by default
+static std::vector g_col = {
+    "",
+    "",
+    "",
+    "",
+    "",
+    "",
+    "",
+    "",
+    "",
+};
+
+struct common_log_entry {
+    enum lm_ggml_log_level level;
+
+    bool prefix;
+
+    int64_t timestamp;
+
+    std::vector msg;
+
+    // signals the worker thread to stop
+    bool is_end;
+
+    #if defined(__ANDROID__) && defined(RNLLAMA_ANDROID_ENABLE_LOGGING)
+    void android_print() const {
+        int android_log_priority;
+        switch (level) {
+            case LM_GGML_LOG_LEVEL_INFO:
+                android_log_priority = ANDROID_LOG_INFO;
+                break;
+            case LM_GGML_LOG_LEVEL_WARN:
+                android_log_priority = ANDROID_LOG_WARN;
+                break;
+            case LM_GGML_LOG_LEVEL_ERROR:
+                android_log_priority = ANDROID_LOG_ERROR;
+                break;
+            case LM_GGML_LOG_LEVEL_DEBUG:
+                android_log_priority = ANDROID_LOG_DEBUG;
+                break;
+            default:
+                android_log_priority = ANDROID_LOG_DEFAULT;
+                break;
+        }
+
+        const char * tag = "RNLLAMA_LOG_ANDROID"; 
+        __android_log_print(android_log_priority, tag, "%s", msg.data());
+    }
+    #endif
+
+    void print(FILE * file = nullptr) const {
+        #if defined(__ANDROID__) && defined(RNLLAMA_ANDROID_ENABLE_LOGGING)
+        android_print();
+        #else
+        FILE * fcur = file;
+        if (!fcur) {
+            // stderr displays DBG messages only when their verbosity level is not higher than the threshold
+            // these messages will still be logged to a file
+            if (level == LM_GGML_LOG_LEVEL_DEBUG && common_log_verbosity_thold < LOG_DEFAULT_DEBUG) {
+                return;
+            }
+
+            fcur = stdout;
+
+            if (level != LM_GGML_LOG_LEVEL_NONE) {
+                fcur = stderr;
+            }
+        }
+
+        if (level != LM_GGML_LOG_LEVEL_NONE && level != LM_GGML_LOG_LEVEL_CONT && prefix) {
+            if (timestamp) {
+                // [M.s.ms.us]
+                fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
+                        g_col[COMMON_LOG_COL_BLUE],
+                        (int) (timestamp / 1000000 / 60),
+                        (int) (timestamp / 1000000 % 60),
+                        (int) (timestamp / 1000 % 1000),
+                        (int) (timestamp % 1000),
+                        g_col[COMMON_LOG_COL_DEFAULT]);
+            }
+
+            switch (level) {
+                case LM_GGML_LOG_LEVEL_INFO:  fprintf(fcur, "%sI %s", g_col[COMMON_LOG_COL_GREEN],   g_col[COMMON_LOG_COL_DEFAULT]); break;
+                case LM_GGML_LOG_LEVEL_WARN:  fprintf(fcur, "%sW %s", g_col[COMMON_LOG_COL_MAGENTA], ""                        ); break;
+                case LM_GGML_LOG_LEVEL_ERROR: fprintf(fcur, "%sE %s", g_col[COMMON_LOG_COL_RED],     ""                        ); break;
+                case LM_GGML_LOG_LEVEL_DEBUG: fprintf(fcur, "%sD %s", g_col[COMMON_LOG_COL_YELLOW],  ""                        ); break;
+                default:
+                    break;
+            }
+        }
+
+        fprintf(fcur, "%s", msg.data());
+
+        if (level == LM_GGML_LOG_LEVEL_WARN || level == LM_GGML_LOG_LEVEL_ERROR || level == LM_GGML_LOG_LEVEL_DEBUG) {
+            fprintf(fcur, "%s", g_col[COMMON_LOG_COL_DEFAULT]);
+        }
+
+        fflush(fcur);
+        #endif
+    }
+};
+
+struct common_log {
+    // default capacity - will be expanded if needed
+    common_log() : common_log(256) {}
+
+    common_log(size_t capacity) {
+        file = nullptr;
+        prefix = false;
+        timestamps = false;
+        running = false;
+        t_start = t_us();
+
+        // initial message size - will be expanded if longer messages arrive
+        entries.resize(capacity);
+        for (auto & entry : entries) {
+            entry.msg.resize(256);
+        }
+
+        head = 0;
+        tail = 0;
+
+        resume();
+    }
+
+    ~common_log() {
+        pause();
+        if (file) {
+            fclose(file);
+        }
+    }
+
+private:
+    std::mutex mtx;
+    std::thread thrd;
+    std::condition_variable cv;
+
+    FILE * file;
+
+    bool prefix;
+    bool timestamps;
+    bool running;
+
+    int64_t t_start;
+
+    // ring buffer of entries
+    std::vector entries;
+    size_t head;
+    size_t tail;
+
+    // worker thread copies into this
+    common_log_entry cur;
+
+public:
+    void add(enum lm_ggml_log_level level, const char * fmt, va_list args) {
+        std::lock_guard lock(mtx);
+
+        if (!running) {
+            // discard messages while the worker thread is paused
+            return;
+        }
+
+        auto & entry = entries[tail];
+
+        {
+            // cannot use args twice, so make a copy in case we need to expand the buffer
+            va_list args_copy;
+            va_copy(args_copy, args);
+
+#if 1
+            const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args);
+            if (n >= entry.msg.size()) {
+                entry.msg.resize(n + 1);
+                vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args_copy);
+            }
+#else
+            // hack for bolding arguments
+
+            std::stringstream ss;
+            for (int i = 0; fmt[i] != 0; i++) {
+                if (fmt[i] == '%') {
+                    ss << LOG_COL_BOLD;
+                    while (fmt[i] != ' ' && fmt[i] != ')' && fmt[i] != ']' && fmt[i] != 0) ss << fmt[i++];
+                    ss << LOG_COL_DEFAULT;
+                    if (fmt[i] == 0) break;
+                }
+                ss << fmt[i];
+            }
+            const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args);
+            if (n >= entry.msg.size()) {
+                entry.msg.resize(n + 1);
+                vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
+            }
+#endif
+        }
+
+        entry.level = level;
+        entry.prefix = prefix;
+        entry.timestamp = 0;
+        if (timestamps) {
+            entry.timestamp = t_us() - t_start;
+        }
+        entry.is_end = false;
+
+        tail = (tail + 1) % entries.size();
+        if (tail == head) {
+            // expand the buffer
+            std::vector new_entries(2*entries.size());
+
+            size_t new_tail = 0;
+
+            do {
+                new_entries[new_tail] = std::move(entries[head]);
+
+                head     = (head     + 1) % entries.size();
+                new_tail = (new_tail + 1);
+            } while (head != tail);
+
+            head = 0;
+            tail = new_tail;
+
+            for (size_t i = tail; i < new_entries.size(); i++) {
+                new_entries[i].msg.resize(256);
+            }
+
+            entries = std::move(new_entries);
+        }
+
+        cv.notify_one();
+    }
+
+    void resume() {
+        std::lock_guard lock(mtx);
+
+        if (running) {
+            return;
+        }
+
+        running = true;
+
+        thrd = std::thread([this]() {
+            while (true) {
+                {
+                    std::unique_lock lock(mtx);
+                    cv.wait(lock, [this]() { return head != tail; });
+
+                    cur = entries[head];
+
+                    head = (head + 1) % entries.size();
+                }
+
+                if (cur.is_end) {
+                    break;
+                }
+
+                cur.print(); // stdout and stderr
+
+                if (file) {
+                    cur.print(file);
+                }
+            }
+        });
+    }
+
+    void pause() {
+        {
+            std::lock_guard lock(mtx);
+
+            if (!running) {
+                return;
+            }
+
+            running = false;
+
+            // push an entry to signal the worker thread to stop
+            {
+                auto & entry = entries[tail];
+                entry.is_end = true;
+
+                tail = (tail + 1) % entries.size();
+            }
+
+            cv.notify_one();
+        }
+
+        thrd.join();
+    }
+
+    void set_file(const char * path) {
+        pause();
+
+        if (file) {
+            fclose(file);
+        }
+
+        if (path) {
+            file = fopen(path, "w");
+        } else {
+            file = nullptr;
+        }
+
+        resume();
+    }
+
+    void set_colors(bool colors) {
+        pause();
+
+        if (colors) {
+            g_col[COMMON_LOG_COL_DEFAULT] = LOG_COL_DEFAULT;
+            g_col[COMMON_LOG_COL_BOLD]    = LOG_COL_BOLD;
+            g_col[COMMON_LOG_COL_RED]     = LOG_COL_RED;
+            g_col[COMMON_LOG_COL_GREEN]   = LOG_COL_GREEN;
+            g_col[COMMON_LOG_COL_YELLOW]  = LOG_COL_YELLOW;
+            g_col[COMMON_LOG_COL_BLUE]    = LOG_COL_BLUE;
+            g_col[COMMON_LOG_COL_MAGENTA] = LOG_COL_MAGENTA;
+            g_col[COMMON_LOG_COL_CYAN]    = LOG_COL_CYAN;
+            g_col[COMMON_LOG_COL_WHITE]   = LOG_COL_WHITE;
+        } else {
+            for (size_t i = 0; i < g_col.size(); i++) {
+                g_col[i] = "";
+            }
+        }
+
+        resume();
+    }
+
+    void set_prefix(bool prefix) {
+        std::lock_guard lock(mtx);
+
+        this->prefix = prefix;
+    }
+
+    void set_timestamps(bool timestamps) {
+        std::lock_guard lock(mtx);
+
+        this->timestamps = timestamps;
+    }
+};
+
+//
+// public API
+//
+
+struct common_log * common_log_init() {
+    return new common_log;
+}
+
+struct common_log * common_log_main() {
+    static struct common_log log;
+
+    return &log;
+}
+
+void common_log_pause(struct common_log * log) {
+    log->pause();
+}
+
+void common_log_resume(struct common_log * log) {
+    log->resume();
+}
+
+void common_log_free(struct common_log * log) {
+    delete log;
+}
+
+void common_log_add(struct common_log * log, enum lm_ggml_log_level level, const char * fmt, ...) {
+    va_list args;
+    va_start(args, fmt);
+    log->add(level, fmt, args);
+    va_end(args);
+}
+
+void common_log_set_file(struct common_log * log, const char * file) {
+    log->set_file(file);
+}
+
+void common_log_set_colors(struct common_log * log, bool colors) {
+    log->set_colors(colors);
+}
+
+void common_log_set_prefix(struct common_log * log, bool prefix) {
+    log->set_prefix(prefix);
+}
+
+void common_log_set_timestamps(struct common_log * log, bool timestamps) {
+    log->set_timestamps(timestamps);
+}
diff --git a/cpp/log.h b/cpp/log.h
index daad7e4..de27c35 100644
--- a/cpp/log.h
+++ b/cpp/log.h
@@ -1,737 +1,92 @@
 #pragma once
 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
+#include "ggml.h" // for lm_ggml_log_level
 
-// --------------------------------
-//
-// Basic usage:
-//
-// --------
-//
-//  The LOG() and LOG_TEE() macros are ready to go by default
-//   they do not require any initialization.
-//
-//  LOGLN() and LOG_TEELN() are variants which automatically
-//   include \n character at the end of the log string.
-//
-//  LOG() behaves exactly like printf, by default writing to a logfile.
-//  LOG_TEE() additionally, prints to the screen too ( mimics Unix tee command ).
-//
-//  Default logfile is named
-//   "llama..log"
-//  Default LOG_TEE() secondary output target is
-//   stderr
-//
-//  Logs can be dynamically disabled or enabled using functions:
-//   log_disable()
-//  and
-//   log_enable()
-//
-//  A log target can be changed with:
-//   log_set_target( string )
-//    creating and opening, or re-opening a file by string filename
-//  or
-//   log_set_target( FILE* )
-//    allowing to point at stderr, stdout, or any valid FILE* file handler.
-//
-// --------
-//
-// End of Basic usage.
-//
-// --------------------------------
-
-// Specifies a log target.
-//  default uses log_handler() with "llama.log" log file
-//  this can be changed, by defining LOG_TARGET
-//  like so:
-//
-//  #define LOG_TARGET (a valid FILE*)
-//  #include "log.h"
-//
-//  or it can be simply redirected to stdout or stderr
-//  like so:
-//
-//  #define LOG_TARGET stderr
-//  #include "log.h"
-//
-//  The log target can also be redirected to a different function
-//  like so:
-//
-//  #define LOG_TARGET log_handler_different()
-//  #include "log.h"
-//
-//  FILE* log_handler_different()
-//  {
-//      return stderr;
-//  }
-//
-//  or:
-//
-//  #define LOG_TARGET log_handler_another_one("somelog.log")
-//  #include "log.h"
-//
-//  FILE* log_handler_another_one(char*filename)
-//  {
-//      static FILE* logfile = nullptr;
-//      (...)
-//      if( !logfile )
-//      {
-//          fopen(...)
-//      }
-//      (...)
-//      return logfile
-//  }
-//
-#ifndef LOG_TARGET
-    #define LOG_TARGET log_handler()
-#endif
-
-#ifndef LOG_TEE_TARGET
-    #define LOG_TEE_TARGET stderr
+#ifndef __GNUC__
+#    define LOG_ATTRIBUTE_FORMAT(...)
+#elif defined(__MINGW32__)
+#    define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#    define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
 #endif
 
-// Utility for synchronizing log configuration state
-//  since std::optional was introduced only in c++17
-enum LogTriState
-{
-    LogTriStateSame,
-    LogTriStateFalse,
-    LogTriStateTrue
-};
-
-// Utility to obtain "pid" like unique process id and use it when creating log files.
-inline std::string log_get_pid()
-{
-   static std::string pid;
-   if (pid.empty())
-   {
-       // std::this_thread::get_id() is the most portable way of obtaining a "process id"
-       //  it's not the same as "pid" but is unique enough to solve multiple instances
-       //  trying to write to the same log.
-       std::stringstream ss;
-       ss << std::this_thread::get_id();
-       pid = ss.str();
-   }
-
-   return pid;
-}
-
-// Utility function for generating log file names with unique id based on thread id.
-//  invocation with log_filename_generator( "llama", "log" ) creates a string "llama..log"
-//  where the number is a runtime id of the current thread.
+#define LOG_DEFAULT_DEBUG 1
+#define LOG_DEFAULT_LLAMA 0
 
-#define log_filename_generator(log_file_basename, log_file_extension) log_filename_generator_impl(LogTriStateSame, log_file_basename, log_file_extension)
+// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
+// set via common_log_set_verbosity()
+extern int common_log_verbosity_thold;
 
-// INTERNAL, DO NOT USE
-inline std::string log_filename_generator_impl(LogTriState multilog, const std::string & log_file_basename, const std::string & log_file_extension)
-{
-    static bool _multilog = false;
+void common_log_set_verbosity_thold(int verbosity); // not thread-safe
 
-    if (multilog != LogTriStateSame)
-    {
-        _multilog = multilog == LogTriStateTrue;
-    }
+// the common_log uses an internal worker thread to print/write log messages
+// when the worker thread is paused, incoming log messages are discarded
+struct common_log;
 
-    std::stringstream buf;
-
-    buf << log_file_basename;
-    if (_multilog)
-    {
-        buf << ".";
-        buf << log_get_pid();
-    }
-    buf << ".";
-    buf << log_file_extension;
-
-    return buf.str();
-}
-
-#ifndef LOG_DEFAULT_FILE_NAME
-    #define LOG_DEFAULT_FILE_NAME log_filename_generator("llama", "log")
-#endif
+struct common_log * common_log_init();
+struct common_log * common_log_main(); // singleton, automatically destroys itself on exit
+void                common_log_pause (struct common_log * log); // pause  the worker thread, not thread-safe
+void                common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe
+void                common_log_free  (struct common_log * log);
 
-// Utility for turning #define values into string literals
-//  so we can have a define for stderr and
-//  we can print "stderr" instead of literal stderr, etc.
-#define LOG_STRINGIZE1(s) #s
-#define LOG_STRINGIZE(s) LOG_STRINGIZE1(s)
+LOG_ATTRIBUTE_FORMAT(3, 4)
+void common_log_add(struct common_log * log, enum lm_ggml_log_level level, const char * fmt, ...);
 
-#define LOG_TEE_TARGET_STRING LOG_STRINGIZE(LOG_TEE_TARGET)
-
-// Allows disabling timestamps.
-//  in order to disable, define LOG_NO_TIMESTAMPS
-//  like so:
+// defaults: file = NULL, colors = false, prefix = false, timestamps = false
 //
-//  #define LOG_NO_TIMESTAMPS
-//  #include "log.h"
+// regular log output:
 //
-#ifndef LOG_NO_TIMESTAMPS
-    #ifndef _MSC_VER
-        #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] "
-        #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count()
-    #else
-        #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] "
-        #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count()
-    #endif
-#else
-    #define LOG_TIMESTAMP_FMT "%s"
-    #define LOG_TIMESTAMP_VAL ,""
-#endif
-
-#ifdef LOG_TEE_TIMESTAMPS
-    #ifndef _MSC_VER
-        #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] "
-        #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count()
-    #else
-        #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] "
-        #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count()
-    #endif
-#else
-    #define LOG_TEE_TIMESTAMP_FMT "%s"
-    #define LOG_TEE_TIMESTAMP_VAL ,""
-#endif
-
-// Allows disabling file/line/function prefix
-//  in order to disable, define LOG_NO_FILE_LINE_FUNCTION
-//  like so:
+//   lm_ggml_backend_metal_log_allocated_size: allocated buffer, size =  6695.84 MiB, ( 6695.91 / 21845.34)
+//   llm_load_tensors: ggml ctx size =    0.27 MiB
+//   llm_load_tensors: offloading 32 repeating layers to GPU
+//   llm_load_tensors: offloading non-repeating layers to GPU
 //
-//  #define LOG_NO_FILE_LINE_FUNCTION
-//  #include "log.h"
+// with prefix = true, timestamps = true, the log output will look like this:
 //
-#ifndef LOG_NO_FILE_LINE_FUNCTION
-    #ifndef _MSC_VER
-        #define LOG_FLF_FMT "[%24s:%5d][%24s] "
-        #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__
-    #else
-        #define LOG_FLF_FMT "[%24s:%5ld][%24s] "
-        #define LOG_FLF_VAL , __FILE__, (long)__LINE__, __FUNCTION__
-    #endif
-#else
-    #define LOG_FLF_FMT "%s"
-    #define LOG_FLF_VAL ,""
-#endif
-
-#ifdef LOG_TEE_FILE_LINE_FUNCTION
-    #ifndef _MSC_VER
-        #define LOG_TEE_FLF_FMT "[%24s:%5d][%24s] "
-        #define LOG_TEE_FLF_VAL , __FILE__, __LINE__, __FUNCTION__
-    #else
-        #define LOG_TEE_FLF_FMT "[%24s:%5ld][%24s] "
-        #define LOG_TEE_FLF_VAL , __FILE__, (long)__LINE__, __FUNCTION__
-    #endif
-#else
-    #define LOG_TEE_FLF_FMT "%s"
-    #define LOG_TEE_FLF_VAL ,""
-#endif
-
-// INTERNAL, DO NOT USE
-//  USE LOG() INSTEAD
+//   0.00.035.060 D lm_ggml_backend_metal_log_allocated_size: allocated buffer, size =  6695.84 MiB, ( 6695.91 / 21845.34)
+//   0.00.035.064 I llm_load_tensors: ggml ctx size =    0.27 MiB
+//   0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
+//   0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
 //
-#if !defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER) || defined(__clang__)
-    #define LOG_IMPL(str, ...)                                                                                      \
-    do {                                                                                                            \
-        if (LOG_TARGET != nullptr)                                                                                  \
-        {                                                                                                           \
-            fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \
-            fflush(LOG_TARGET);                                                                                     \
-        }                                                                                                           \
-    } while (0)
-#else
-    #define LOG_IMPL(str, ...)                                                                                           \
-    do {                                                                                                                 \
-        if (LOG_TARGET != nullptr)                                                                                       \
-        {                                                                                                                \
-            fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \
-            fflush(LOG_TARGET);                                                                                          \
-        }                                                                                                                \
-    } while (0)
-#endif
-
-// INTERNAL, DO NOT USE
-//  USE LOG_TEE() INSTEAD
+// I - info    (stdout, V = 0)
+// W - warning (stderr, V = 0)
+// E - error   (stderr, V = 0)
+// D - debug   (stderr, V = LOG_DEFAULT_DEBUG)
 //
-#if !defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER) || defined(__clang__)
-    #define LOG_TEE_IMPL(str, ...)                                                                                                      \
-    do {                                                                                                                                \
-        if (LOG_TARGET != nullptr)                                                                                                      \
-        {                                                                                                                               \
-            fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__);                     \
-            fflush(LOG_TARGET);                                                                                                         \
-        }                                                                                                                               \
-        if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr)                         \
-        {                                                                                                                               \
-            fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL, __VA_ARGS__); \
-            fflush(LOG_TEE_TARGET);                                                                                                     \
-        }                                                                                                                               \
-    } while (0)
-#else
-    #define LOG_TEE_IMPL(str, ...)                                                                                                           \
-    do {                                                                                                                                     \
-        if (LOG_TARGET != nullptr)                                                                                                           \
-        {                                                                                                                                    \
-            fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__);                     \
-            fflush(LOG_TARGET);                                                                                                              \
-        }                                                                                                                                    \
-        if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr)                              \
-        {                                                                                                                                    \
-            fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL "", ##__VA_ARGS__); \
-            fflush(LOG_TEE_TARGET);                                                                                                          \
-        }                                                                                                                                    \
-    } while (0)
-#endif
 
-// The '\0' as a last argument, is a trick to bypass the silly
-//  "warning: ISO C++11 requires at least one argument for the "..." in a variadic macro"
-//  so we can have a single macro which can be called just like printf.
+void common_log_set_file      (struct common_log * log, const char * file);       // not thread-safe
+void common_log_set_colors    (struct common_log * log,       bool   colors);     // not thread-safe
+void common_log_set_prefix    (struct common_log * log,       bool   prefix);     // whether to output prefix to each log
+void common_log_set_timestamps(struct common_log * log,       bool   timestamps); // whether to output timestamps in the prefix
 
-// Main LOG macro.
-//  behaves like printf, and supports arguments the exact same way.
+// helper macros for logging
+// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
 //
-#if !defined(_MSC_VER) || defined(__clang__)
-    #define LOG(...) LOG_IMPL(__VA_ARGS__, "")
-#else
-    #define LOG(str, ...) LOG_IMPL("%s" str, "", ##__VA_ARGS__, "")
-#endif
-
-// Main TEE macro.
-//  does the same as LOG
-//  and
-//  simultaneously writes stderr.
+// for example:
 //
-// Secondary target can be changed just like LOG_TARGET
-//  by defining LOG_TEE_TARGET
+//   LOG_DBG("this is a debug message: %d\n", expensive_function());
+//
+// this will avoid calling expensive_function() if LOG_DEFAULT_DEBUG > common_log_verbosity_thold
 //
-#if !defined(_MSC_VER) || defined(__clang__)
-    #define LOG_TEE(...) LOG_TEE_IMPL(__VA_ARGS__, "")
-#else
-    #define LOG_TEE(str, ...) LOG_TEE_IMPL("%s" str, "", ##__VA_ARGS__, "")
-#endif
-
-// LOG macro variants with auto endline.
-#if !defined(_MSC_VER) || defined(__clang__)
-    #define LOGLN(...) LOG_IMPL(__VA_ARGS__, "\n")
-    #define LOG_TEELN(...) LOG_TEE_IMPL(__VA_ARGS__, "\n")
-#else
-    #define LOGLN(str, ...) LOG_IMPL("%s" str, "", ##__VA_ARGS__, "\n")
-    #define LOG_TEELN(str, ...) LOG_TEE_IMPL("%s" str, "", ##__VA_ARGS__, "\n")
-#endif
-
-#if defined(__ANDROID__) && defined(RNLLAMA_ANDROID_ENABLE_LOGGING)
-#include 
-#define LLAMA_ANDROID_LOG_TAG "RNLLAMA_LOG_ANDROID"
-#undef LOG
-#undef LOG_TEE
-#undef LOGLN
-#undef LOG_TEELN
-#define LOG(...) __android_log_print(ANDROID_LOG_INFO, LLAMA_ANDROID_LOG_TAG, __VA_ARGS__)
-#define LOG_TEE(...) __android_log_print(ANDROID_LOG_INFO, LLAMA_ANDROID_LOG_TAG, __VA_ARGS__)
-#define LOGLN(...) __android_log_print(ANDROID_LOG_INFO, LLAMA_ANDROID_LOG_TAG, __VA_ARGS__)
-#define LOG_TEELN(...) __android_log_print(ANDROID_LOG_INFO, LLAMA_ANDROID_LOG_TAG, __VA_ARGS__)
-#endif
-
-// INTERNAL, DO NOT USE
-inline FILE *log_handler1_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, const std::string & filename = LOG_DEFAULT_FILE_NAME, FILE *target = nullptr)
-{
-    static bool _initialized = false;
-    static bool _append = false;
-    static bool _disabled = filename.empty() && target == nullptr;
-    static std::string log_current_filename{filename};
-    static FILE *log_current_target{target};
-    static FILE *logfile = nullptr;
-
-    if (change)
-    {
-        if (append != LogTriStateSame)
-        {
-            _append = append == LogTriStateTrue;
-            return logfile;
-        }
-
-        if (disable == LogTriStateTrue)
-        {
-            // Disable primary target
-            _disabled = true;
-        }
-        // If previously disabled, only enable, and keep previous target
-        else if (disable == LogTriStateFalse)
-        {
-            _disabled = false;
-        }
-        // Otherwise, process the arguments
-        else if (log_current_filename != filename || log_current_target != target)
-        {
-            _initialized = false;
-        }
-    }
-
-    if (_disabled)
-    {
-        // Log is disabled
-        return nullptr;
-    }
-
-    if (_initialized)
-    {
-        // with fallback in case something went wrong
-        return logfile ? logfile : stderr;
-    }
-
-    // do the (re)initialization
-    if (target != nullptr)
-    {
-        if (logfile != nullptr && logfile != stdout && logfile != stderr)
-        {
-            fclose(logfile);
-        }
-
-        log_current_filename = LOG_DEFAULT_FILE_NAME;
-        log_current_target = target;
-
-        logfile = target;
-    }
-    else
-    {
-        if (log_current_filename != filename)
-        {
-            if (logfile != nullptr && logfile != stdout && logfile != stderr)
-            {
-                fclose(logfile);
-            }
-        }
-
-        logfile = fopen(filename.c_str(), _append ? "a" : "w");
-    }
-
-    if (!logfile)
-    {
-        //  Verify whether the file was opened, otherwise fallback to stderr
-        logfile = stderr;
-
-        fprintf(stderr, "Failed to open logfile '%s' with error '%s'\n", filename.c_str(), std::strerror(errno));
-        fflush(stderr);
-
-        // At this point we let the init flag be to true below, and let the target fallback to stderr
-        //  otherwise we would repeatedly fopen() which was already unsuccessful
-    }
-
-    _initialized = true;
-
-    return logfile ? logfile : stderr;
-}
-
-// INTERNAL, DO NOT USE
-inline FILE *log_handler2_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, FILE *target = nullptr, const std::string & filename = LOG_DEFAULT_FILE_NAME)
-{
-    return log_handler1_impl(change, append, disable, filename, target);
-}
-
-// Disables logs entirely at runtime.
-//  Makes LOG() and LOG_TEE() produce no output,
-//  until enabled back.
-#define log_disable() log_disable_impl()
-
-// INTERNAL, DO NOT USE
-inline FILE *log_disable_impl()
-{
-    return log_handler1_impl(true, LogTriStateSame, LogTriStateTrue);
-}
-
-// Enables logs at runtime.
-#define log_enable() log_enable_impl()
-
-// INTERNAL, DO NOT USE
-inline FILE *log_enable_impl()
-{
-    return log_handler1_impl(true, LogTriStateSame, LogTriStateFalse);
-}
-
-// Sets target fir logs, either by a file name or FILE* pointer (stdout, stderr, or any valid FILE*)
-#define log_set_target(target) log_set_target_impl(target)
-
-// INTERNAL, DO NOT USE
-inline FILE *log_set_target_impl(const std::string & filename) { return log_handler1_impl(true, LogTriStateSame, LogTriStateSame, filename); }
-inline FILE *log_set_target_impl(FILE *target) { return log_handler2_impl(true, LogTriStateSame, LogTriStateSame, target); }
-
-// INTERNAL, DO NOT USE
-inline FILE *log_handler() { return log_handler1_impl(); }
-
-// Enable or disable creating separate log files for each run.
-//  can ONLY be invoked BEFORE first log use.
-#define log_multilog(enable) log_filename_generator_impl((enable) ? LogTriStateTrue : LogTriStateFalse, "", "")
-// Enable or disable append mode for log file.
-//  can ONLY be invoked BEFORE first log use.
-#define log_append(enable) log_append_impl(enable)
-// INTERNAL, DO NOT USE
-inline FILE *log_append_impl(bool enable)
-{
-    return log_handler1_impl(true, enable ? LogTriStateTrue : LogTriStateFalse, LogTriStateSame);
-}
-
-inline void log_test()
-{
-    log_disable();
-    LOG("01 Hello World to nobody, because logs are disabled!\n");
-    log_enable();
-    LOG("02 Hello World to default output, which is \"%s\" ( Yaaay, arguments! )!\n", LOG_STRINGIZE(LOG_TARGET));
-    LOG_TEE("03 Hello World to **both** default output and " LOG_TEE_TARGET_STRING "!\n");
-    log_set_target(stderr);
-    LOG("04 Hello World to stderr!\n");
-    LOG_TEE("05 Hello World TEE with double printing to stderr prevented!\n");
-    log_set_target(LOG_DEFAULT_FILE_NAME);
-    LOG("06 Hello World to default log file!\n");
-    log_set_target(stdout);
-    LOG("07 Hello World to stdout!\n");
-    log_set_target(LOG_DEFAULT_FILE_NAME);
-    LOG("08 Hello World to default log file again!\n");
-    log_disable();
-    LOG("09 Hello World _1_ into the void!\n");
-    log_enable();
-    LOG("10 Hello World back from the void ( you should not see _1_ in the log or the output )!\n");
-    log_disable();
-    log_set_target("llama.anotherlog.log");
-    LOG("11 Hello World _2_ to nobody, new target was selected but logs are still disabled!\n");
-    log_enable();
-    LOG("12 Hello World this time in a new file ( you should not see _2_ in the log or the output )?\n");
-    log_set_target("llama.yetanotherlog.log");
-    LOG("13 Hello World this time in yet new file?\n");
-    log_set_target(log_filename_generator("llama_autonamed", "log"));
-    LOG("14 Hello World in log with generated filename!\n");
-#ifdef _MSC_VER
-    LOG_TEE("15 Hello msvc TEE without arguments\n");
-    LOG_TEE("16 Hello msvc TEE with (%d)(%s) arguments\n", 1, "test");
-    LOG_TEELN("17 Hello msvc TEELN without arguments\n");
-    LOG_TEELN("18 Hello msvc TEELN with (%d)(%s) arguments\n", 1, "test");
-    LOG("19 Hello msvc LOG without arguments\n");
-    LOG("20 Hello msvc LOG with (%d)(%s) arguments\n", 1, "test");
-    LOGLN("21 Hello msvc LOGLN without arguments\n");
-    LOGLN("22 Hello msvc LOGLN with (%d)(%s) arguments\n", 1, "test");
-#endif
-}
-
-inline bool log_param_single_parse(const std::string & param)
-{
-    if ( param == "--log-test")
-    {
-        log_test();
-        return true;
-    }
-
-    if ( param == "--log-disable")
-    {
-        log_disable();
-        return true;
-    }
-
-    if ( param == "--log-enable")
-    {
-        log_enable();
-        return true;
-    }
-
-    if (param == "--log-new")
-    {
-        log_multilog(true);
-        return true;
-    }
-
-    if (param == "--log-append")
-    {
-        log_append(true);
-        return true;
-    }
-
-    return false;
-}
-
-inline bool log_param_pair_parse(bool check_but_dont_parse, const std::string & param, const std::string & next = std::string())
-{
-    if ( param == "--log-file")
-    {
-        if (!check_but_dont_parse)
-        {
-            log_set_target(log_filename_generator(next.empty() ? "unnamed" : next, "log"));
-        }
-
-        return true;
-    }
-
-    return false;
-}
-
-inline void log_print_usage()
-{
-    printf("log options:\n");
-    /* format
-    printf("  -h, --help            show this help message and exit\n");*/
-    /* spacing
-    printf("__-param----------------Description\n");*/
-    printf("  --log-test            Run simple logging test\n");
-    printf("  --log-disable         Disable trace logs\n");
-    printf("  --log-enable          Enable trace logs\n");
-    printf("  --log-file            Specify a log filename (without extension)\n");
-    printf("  --log-new             Create a separate new log file on start. "
-                                   "Each log file will have unique name: \"..log\"\n");
-    printf("  --log-append          Don't truncate the old log file.\n");
-    printf("\n");
-}
-
-#define log_dump_cmdline(argc, argv) log_dump_cmdline_impl(argc, argv)
-
-// INTERNAL, DO NOT USE
-inline void log_dump_cmdline_impl(int argc, char **argv)
-{
-    std::stringstream buf;
-    for (int i = 0; i < argc; ++i)
-    {
-        if (std::string(argv[i]).find(' ') != std::string::npos)
-        {
-            buf << " \"" << argv[i] <<"\"";
-        }
-        else
-        {
-            buf << " " << argv[i];
-        }
-    }
-    LOGLN("Cmd:%s", buf.str().c_str());
-}
-
-#define log_tostr(var) log_var_to_string_impl(var).c_str()
-
-inline std::string log_var_to_string_impl(bool var)
-{
-    return var ? "true" : "false";
-}
-
-inline std::string log_var_to_string_impl(std::string var)
-{
-    return var;
-}
-
-inline std::string log_var_to_string_impl(const std::vector & var)
-{
-    std::stringstream buf;
-    buf << "[ ";
-    bool first = true;
-    for (auto e : var)
-    {
-        if (first)
-        {
-            first = false;
-        }
-        else
-        {
-            buf << ", ";
-        }
-        buf << std::to_string(e);
-    }
-    buf << " ]";
-
-    return buf.str();
-}
-
-template 
-inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
-{
-    std::stringstream buf;
-    buf << "[ ";
-
-    bool first = true;
-    for (const auto & token : tokens)
-    {
-        if (!first) {
-            buf << ", ";
-        } else {
-            first = false;
-        }
-
-        auto detokenized = llama_token_to_piece(ctx, token);
-
-        detokenized.erase(
-            std::remove_if(
-                detokenized.begin(),
-                detokenized.end(),
-                [](const unsigned char c) { return !std::isprint(c); }),
-            detokenized.end());
-
-        buf
-            << "'" << detokenized << "'"
-            << ":" << std::to_string(token);
-    }
-    buf << " ]";
-
-    return buf.str();
-}
-
-template 
-inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
-{
-    std::stringstream buf;
-    buf << "[ ";
-
-    bool first = true;
-    for (int i = 0; i < batch.n_tokens; ++i)
-    {
-        if (!first) {
-            buf << ", ";
-        } else {
-            first = false;
-        }
-
-        auto detokenized = llama_token_to_piece(ctx, batch.token[i]);
-
-        detokenized.erase(
-            std::remove_if(
-                detokenized.begin(),
-                detokenized.end(),
-                [](const unsigned char c) { return !std::isprint(c); }),
-            detokenized.end());
-
-        buf
-            << "\n" << std::to_string(i)
-            << ":token '" << detokenized << "'"
-            << ":pos " << std::to_string(batch.pos[i])
-            << ":n_seq_id  " << std::to_string(batch.n_seq_id[i])
-            << ":seq_id " << std::to_string(batch.seq_id[i][0])
-            << ":logits " << std::to_string(batch.logits[i]);
-    }
-    buf << " ]";
-
-    return buf.str();
-}
-
-#ifdef LOG_DISABLE_LOGS
-
-#undef LOG
-#define LOG(...) // dummy stub
-#undef LOGLN
-#define LOGLN(...) // dummy stub
-
-#undef LOG_TEE
-#define LOG_TEE(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf
-
-#undef LOG_TEELN
-#define LOG_TEELN(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf
-
-#undef LOG_DISABLE
-#define LOG_DISABLE() // dummy stub
-
-#undef LOG_ENABLE
-#define LOG_ENABLE() // dummy stub
 
-#undef LOG_ENABLE
-#define LOG_ENABLE() // dummy stub
+#define LOG_TMPL(level, verbosity, ...) \
+    do { \
+        if ((verbosity) <= common_log_verbosity_thold) { \
+            common_log_add(common_log_main(), (level), __VA_ARGS__); \
+        } \
+    } while (0)
 
-#undef LOG_SET_TARGET
-#define LOG_SET_TARGET(...) // dummy stub
+#define LOG(...)             LOG_TMPL(LM_GGML_LOG_LEVEL_NONE, 0,         __VA_ARGS__)
+#define LOGV(verbosity, ...) LOG_TMPL(LM_GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
 
-#undef LOG_DUMP_CMDLINE
-#define LOG_DUMP_CMDLINE(...) // dummy stub
+#define LOG_INF(...) LOG_TMPL(LM_GGML_LOG_LEVEL_INFO,  0,                 __VA_ARGS__)
+#define LOG_WRN(...) LOG_TMPL(LM_GGML_LOG_LEVEL_WARN,  0,                 __VA_ARGS__)
+#define LOG_ERR(...) LOG_TMPL(LM_GGML_LOG_LEVEL_ERROR, 0,                 __VA_ARGS__)
+#define LOG_DBG(...) LOG_TMPL(LM_GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
+#define LOG_CNT(...) LOG_TMPL(LM_GGML_LOG_LEVEL_CONT,  0,                 __VA_ARGS__)
 
-#endif // LOG_DISABLE_LOGS
+#define LOG_INFV(verbosity, ...) LOG_TMPL(LM_GGML_LOG_LEVEL_INFO,  verbosity, __VA_ARGS__)
+#define LOG_WRNV(verbosity, ...) LOG_TMPL(LM_GGML_LOG_LEVEL_WARN,  verbosity, __VA_ARGS__)
+#define LOG_ERRV(verbosity, ...) LOG_TMPL(LM_GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
+#define LOG_DBGV(verbosity, ...) LOG_TMPL(LM_GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
+#define LOG_CNTV(verbosity, ...) LOG_TMPL(LM_GGML_LOG_LEVEL_CONT,  verbosity, __VA_ARGS__)
diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp
index 537f0a1..c8ca2ce 100644
--- a/cpp/rn-llama.hpp
+++ b/cpp/rn-llama.hpp
@@ -5,6 +5,7 @@
 #include 
 #include "common.h"
 #include "llama.h"
+#include "sampling.h"
 
 namespace rnllama {
 
@@ -115,7 +116,7 @@ static size_t find_partial_stop_string(const std::string &stop,
 // format incomplete utf-8 multibyte character for output
 static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
 {
-    std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
+    std::string out = token == -1 ? "" : common_token_to_piece(ctx, token);
     // if the size is 1 and first bit is 1, meaning it's a partial character
     //   (size > 1 meaning it's already a known token)
     if (out.size() == 1 && (out[0] & 0x80) == 0x80)
@@ -134,7 +135,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
     std::string ret;
     for (; begin != end; ++begin)
     {
-        ret += llama_token_to_piece(ctx, *begin);
+        ret += common_token_to_piece(ctx, *begin);
     }
     return ret;
 }
@@ -154,12 +155,12 @@ struct llama_rn_context
 
     std::vector embd;
 
-    gpt_params params;
+    common_params params;
 
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
-    llama_sampling_context *ctx_sampling = nullptr;
-
+    common_sampler *ctx_sampling = nullptr;
+  
     int n_ctx;
 
     bool truncated = false;
@@ -183,7 +184,7 @@ struct llama_rn_context
         }
         if (ctx_sampling != nullptr)
         {
-            llama_sampling_free(ctx_sampling);
+            common_sampler_free(ctx_sampling);
         }
     }
 
@@ -210,18 +211,18 @@ struct llama_rn_context
 
     bool initSampling() {
         if (ctx_sampling != nullptr) {
-            llama_sampling_free(ctx_sampling);
+            common_sampler_free(ctx_sampling);
         }
-        ctx_sampling = llama_sampling_init(params.sparams);
+        ctx_sampling = common_sampler_init(model, params.sparams);
         return ctx_sampling != nullptr;
     }
 
-    bool loadModel(gpt_params ¶ms_)
+    bool loadModel(common_params ¶ms_)
     {
         params = params_;
-        llama_init_result llama_init = llama_init_from_gpt_params(params);
-        model = llama_init.model;
-        ctx = llama_init.context;
+        common_init_result result = common_init_from_params(params);
+        model = result.model;
+        ctx = result.context;
         if (model == nullptr)
         {
            LOG_ERROR("unable to load model: %s", params_.model.c_str());
@@ -265,7 +266,7 @@ struct llama_rn_context
 
     void loadPrompt()
     {
-        std::vector prompt_tokens = ::llama_tokenize(ctx, params.prompt, true, true);
+        std::vector prompt_tokens = ::common_tokenize(ctx, params.prompt, true, true);
         num_prompt_tokens = prompt_tokens.size();
 
         // LOG tokens
@@ -293,7 +294,7 @@ struct llama_rn_context
         // push the prompt into the sampling context (do not apply grammar)
         for (auto & token : prompt_tokens)
         {
-           llama_sampling_accept(ctx_sampling, ctx, token, false);
+           common_sampler_accept(ctx_sampling, token, false);
         }
 
         // compare the evaluated prompt with the new prompt
@@ -322,8 +323,7 @@ struct llama_rn_context
     {
         // number of tokens to keep when resetting context
         n_remain = params.n_predict;
-        llama_set_rng_seed(ctx, params.seed);
-
+        llama_perf_context_reset(ctx);
         is_predicting = true;
     }
 
@@ -366,18 +366,25 @@ struct llama_rn_context
             {
                 n_eval = params.n_batch;
             }
-            if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0)))
-            {
+            if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval)))
+            {   
                 LOG_ERROR("failed to eval, n_eval: %d, n_past: %d, n_threads: %d, embd: %s",
                     n_eval,
                     n_past,
-                    params.n_threads,
+                    params.cpuparams.n_threads,
                     tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()).c_str()
                 );
                 has_next_token = false;
                 return result;
             }
             n_past += n_eval;
+            
+            if(is_interrupted) {
+                LOG_INFO("Decoding Interrupted");
+                embd.resize(n_past);
+                has_next_token = false;
+                return result;
+            }
         }
 
         if (params.n_predict == 0)
@@ -392,22 +399,27 @@ struct llama_rn_context
             std::vector candidates;
             candidates.reserve(llama_n_vocab(model));
 
-            result.tok = llama_sampling_sample(ctx_sampling, ctx, NULL);
-
-            llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
+            result.tok = common_sampler_sample(ctx_sampling, ctx, -1);
+            
+            llama_token_data_array cur_p = *common_sampler_get_candidates(ctx_sampling);
 
             const int32_t n_probs = params.sparams.n_probs;
-            if (params.sparams.temp <= 0 && n_probs > 0)
+            
+            // deprecated
+            /*if (params.sparams.temp <= 0 && n_probs > 0)
             {
                 // For llama_sample_token_greedy we need to sort candidates
-                llama_sample_softmax(ctx, &cur_p);
-            }
+                llama_sampler_init_softmax();
+
+            }*/
+            
 
             for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
             {
                 result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
             }
-            llama_sampling_accept(ctx_sampling, ctx, result.tok, true);
+
+            common_sampler_accept(ctx_sampling, result.tok, true);
             if (tg) {
                 num_tokens_predicted++;
             }
@@ -467,7 +479,7 @@ struct llama_rn_context
     {
         const completion_token_output token_with_probs = nextToken();
 
-        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
+        const std::string token_text = token_with_probs.tok == -1 ? "" : common_token_to_piece(ctx, token_with_probs.tok);
         generated_text += token_text;
 
         if (params.sparams.n_probs > 0)
@@ -508,7 +520,7 @@ struct llama_rn_context
         }
 
         LOG_VERBOSE("next token, token: %s, token_text: %s, has_next_token: %d, n_remain: %d, num_tokens_predicted: %d, stopped_eos: %d, stopped_word: %d, stopped_limit: %d, stopping_word: %s",
-            llama_token_to_piece(ctx, token_with_probs.tok),
+            common_token_to_piece(ctx, token_with_probs.tok),
             tokens_to_output_formatted_string(ctx, token_with_probs.tok).c_str(),
             has_next_token,
             n_remain,
@@ -529,9 +541,21 @@ struct llama_rn_context
             LOG_WARNING("embedding disabled, embedding: %s", params.embedding);
             return std::vector(n_embd, 0.0f);
         }
-        const float *data = llama_get_embeddings(ctx);
-        std::vector embedding(data, data + n_embd);
-        return embedding;
+        float *data;
+        
+        if(params.pooling_type == 0){
+            data = llama_get_embeddings(ctx);
+        }
+        else {
+            data = llama_get_embeddings_seq(ctx, 0);
+        }
+        
+        if(!data) {
+            return std::vector(n_embd, 0.0f);
+        }
+        std::vector embedding(data, data + n_embd), out(data, data + n_embd);
+        common_embd_normalize(embedding.data(), out.data(), n_embd, params.embd_normalize);
+        return out;
     }
 
     std::string bench(int pp, int tg, int pl, int nr)
diff --git a/cpp/sampling.cpp b/cpp/sampling.cpp
index e99bbae..66a2311 100644
--- a/cpp/sampling.cpp
+++ b/cpp/sampling.cpp
@@ -1,460 +1,466 @@
-#define LLAMA_API_INTERNAL
 #include "sampling.h"
-#include 
 
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
-    struct llama_sampling_context * result = new llama_sampling_context();
+#include "common.h"
 
-    result->params  = params;
-    result->grammar = nullptr;
+#include 
+#include 
 
-    // if there is a grammar, parse it
-    if (!params.grammar.empty()) {
-        result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+// TODO: deduplicate with llama-impl.h
+template
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
 
-        // will be empty (default) if there are parse errors
-        if (result->parsed_grammar.rules.empty()) {
-            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
-            delete result;
-            return nullptr;
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
+        return data[first];
+    }
 
-        // Ensure that there is a "root" node.
-        if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
-            fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
-            delete result;
-            return nullptr;
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
+        return data[first];
+    }
 
-        std::vector grammar_rules(result->parsed_grammar.c_rules());
-
-        struct llama_grammar * grammar = llama_grammar_init(
-                grammar_rules.data(),
-                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
-        if (grammar == nullptr) {
-            throw std::runtime_error("Failed to initialize llama_grammar");
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
-        result->grammar = grammar;
+        return data[pos];
     }
 
-    result->prev.resize(params.n_prev);
-
-    result->n_valid = 0;
-
-    llama_sampling_set_rng_seed(result, params.seed);
-
-    return result;
-}
-
-void llama_sampling_free(struct llama_sampling_context * ctx) {
-    if (ctx->grammar != NULL) {
-        llama_grammar_free(ctx->grammar);
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
     }
 
-    delete ctx;
-}
-
-void llama_sampling_reset(llama_sampling_context * ctx) {
-    if (ctx->grammar != NULL) {
-        llama_grammar_free(ctx->grammar);
-        ctx->grammar = NULL;
+    void push_back(const T & value) {
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
     }
 
-    if (!ctx->parsed_grammar.rules.empty()) {
-        std::vector grammar_rules(ctx->parsed_grammar.c_rules());
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
 
-        struct llama_grammar * grammar = llama_grammar_init(
-                grammar_rules.data(),
-                grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
-        if (grammar == nullptr) {
-            throw std::runtime_error("Failed to initialize llama_grammar");
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
         }
-        ctx->grammar = grammar;
+        return data[(first + sz - i - 1) % capacity];
     }
 
-    std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
-    ctx->cur.clear();
-    ctx->n_valid = 0;
-}
+    std::vector to_vector() const {
+        std::vector result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
 
-void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
-    if (seed == LLAMA_DEFAULT_SEED) {
-        seed = std::random_device{}();
+    void clear() {
+        // here only reset the status of the buffer
+        sz = 0;
+        first = 0;
+        pos = 0;
     }
-    ctx->rng.seed(seed);
-}
 
-void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
-    if (dst->grammar) {
-        llama_grammar_free(dst->grammar);
-        dst->grammar = nullptr;
+    bool empty() const {
+        return sz == 0;
     }
 
-    if (src->grammar) {
-        dst->grammar = llama_grammar_copy(src->grammar);
+    size_t size() const {
+        return sz;
     }
 
-    dst->prev = src->prev;
-}
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+    std::vector data;
+};
 
-llama_token llama_sampling_last(llama_sampling_context * ctx) {
-    return ctx->prev.back();
-}
+struct common_sampler {
+    common_sampler_params params;
 
-std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
-    const int size = ctx_sampling->prev.size();
+    struct llama_sampler * grmr;
+    struct llama_sampler * chain;
 
-    n = std::min(n, size);
+    ring_buffer prev;
 
-    std::string result;
+    std::vector cur;
 
-    for (int i = size - n; i < size; i++) {
-        result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
-    }
+    llama_token_data_array cur_p;
 
-    return result;
-}
+    void set_logits(struct llama_context * ctx, int idx) {
+        const auto * logits = llama_get_logits_ith(ctx, idx);
+
+        const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        cur.resize(n_vocab);
 
-std::string llama_sampling_print(const llama_sampling_params & params) {
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        }
+
+        cur_p = { cur.data(), cur.size(), -1, false };
+    }
+};
+
+std::string common_sampler_params::print() const {
     char result[1024];
 
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
-            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
+            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
-            params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
-            params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
-            params.mirostat, params.mirostat_eta, params.mirostat_tau);
+            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
+            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
+            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
+            mirostat, mirostat_eta, mirostat_tau);
 
     return std::string(result);
 }
 
-std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
+    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
+
+    lparams.no_perf = params.no_perf;
+
+    auto * result = new common_sampler {
+        /* .params = */ params,
+        /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .chain  = */ llama_sampler_chain_init(lparams),
+        /* .prev   = */ ring_buffer(std::max(32, params.n_prev)),
+        /* .cur    = */ {},
+        /* .cur_p  = */ {},
+    };
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_logit_bias(
+                llama_n_vocab(model),
+                params.logit_bias.size(),
+                params.logit_bias.data()));
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_penalties(
+                llama_n_vocab  (model),
+                llama_token_eos(model),
+                llama_token_nl (model),
+                params.penalty_last_n,
+                params.penalty_repeat,
+                params.penalty_freq,
+                params.penalty_present,
+                params.penalize_nl,
+                params.ignore_eos));
+
     if (params.mirostat == 0) {
-        for (auto sampler_type : params.samplers_sequence) {
-            const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
-            if (!sampler_type_name.empty()) {
-                result += "-> " + sampler_type_name + " ";
+        for (const auto & cnstr : params.samplers) {
+            switch (cnstr) {
+                    case COMMON_SAMPLER_TYPE_DRY:
+                    {
+                        std::vector c_breakers;
+                        c_breakers.reserve(params.dry_sequence_breakers.size());
+                        for (const auto& str : params.dry_sequence_breakers) {
+                            c_breakers.push_back(str.c_str());
+                        }
+
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                    }
+                        break;
+                case COMMON_SAMPLER_TYPE_TOP_K:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
+                    break;
+                case COMMON_SAMPLER_TYPE_TOP_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_MIN_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_XTC:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    break;
+                case COMMON_SAMPLER_TYPE_TYPICAL_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    break;
+                case COMMON_SAMPLER_TYPE_INFILL:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
+                    break;
+                default:
+                    LM_GGML_ASSERT(false && "unknown sampler type");
             }
         }
+        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+    } else if (params.mirostat == 1) {
+        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+    } else if (params.mirostat == 2) {
+        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
     } else {
-        result += "-> mirostat ";
+        LM_GGML_ASSERT(false && "unknown mirostat version");
     }
 
     return result;
 }
 
-std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
-    switch (sampler_type) {
-        case llama_sampler_type::TOP_K:       return "top_k";
-        case llama_sampler_type::TFS_Z:       return "tfs_z";
-        case llama_sampler_type::TYPICAL_P:   return "typical_p";
-        case llama_sampler_type::TOP_P:       return "top_p";
-        case llama_sampler_type::MIN_P:       return "min_p";
-        case llama_sampler_type::TEMPERATURE: return "temperature";
-        default : return "";
+void common_sampler_free(struct common_sampler * gsmpl) {
+    if (gsmpl) {
+        llama_sampler_free(gsmpl->grmr);
+
+        llama_sampler_free(gsmpl->chain);
+
+        delete gsmpl;
     }
 }
 
-std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) {
-    std::unordered_map sampler_canonical_name_map {
-        {"top_k",       llama_sampler_type::TOP_K},
-        {"top_p",       llama_sampler_type::TOP_P},
-        {"typical_p",   llama_sampler_type::TYPICAL_P},
-        {"min_p",       llama_sampler_type::MIN_P},
-        {"tfs_z",       llama_sampler_type::TFS_Z},
-        {"temperature", llama_sampler_type::TEMPERATURE}
-    };
+void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
+    if (accept_grammar) {
+        llama_sampler_accept(gsmpl->grmr, token);
+    }
 
-    // since samplers names are written multiple ways
-    // make it ready for both system names and input names
-    std::unordered_map sampler_alt_name_map {
-        {"top-k",       llama_sampler_type::TOP_K},
-        {"top-p",       llama_sampler_type::TOP_P},
-        {"nucleus",     llama_sampler_type::TOP_P},
-        {"typical-p",   llama_sampler_type::TYPICAL_P},
-        {"typical",     llama_sampler_type::TYPICAL_P},
-        {"min-p",       llama_sampler_type::MIN_P},
-        {"tfs-z",       llama_sampler_type::TFS_Z},
-        {"tfs",         llama_sampler_type::TFS_Z},
-        {"temp",        llama_sampler_type::TEMPERATURE}
-    };
+    llama_sampler_accept(gsmpl->chain, token);
 
-    std::vector sampler_types;
-    sampler_types.reserve(names.size());
-    for (const auto & name : names)
-    {
-        auto sampler_item = sampler_canonical_name_map.find(name);
-        if (sampler_item != sampler_canonical_name_map.end())
-        {
-            sampler_types.push_back(sampler_item->second);
-        }
-        else
-        {
-            if (allow_alt_names)
-            {
-                sampler_item = sampler_alt_name_map.find(name);
-                if (sampler_item != sampler_alt_name_map.end())
-                {
-                    sampler_types.push_back(sampler_item->second);
-                }
-            }
-        }
-    }
-    return sampler_types;
+    gsmpl->prev.push_back(token);
 }
 
-std::vector llama_sampling_types_from_chars(const std::string & names_string) {
-    std::unordered_map sampler_name_map {
-        {'k', llama_sampler_type::TOP_K},
-        {'p', llama_sampler_type::TOP_P},
-        {'y', llama_sampler_type::TYPICAL_P},
-        {'m', llama_sampler_type::MIN_P},
-        {'f', llama_sampler_type::TFS_Z},
-        {'t', llama_sampler_type::TEMPERATURE}
-    };
+void common_sampler_reset(struct common_sampler * gsmpl) {
+    llama_sampler_reset(gsmpl->grmr);
 
-    std::vector sampler_types;
-    sampler_types.reserve(names_string.size());
-    for (const auto & c : names_string) {
-        const auto sampler_item = sampler_name_map.find(c);
-        if (sampler_item != sampler_name_map.end()) {
-            sampler_types.push_back(sampler_item->second);
-        }
-    }
-    return sampler_types;
+    llama_sampler_reset(gsmpl->chain);
 }
 
-// no reasons to expose this function in header
-static void sampler_queue(
-                   struct llama_context * ctx_main,
-            const llama_sampling_params & params,
-                 llama_token_data_array & cur_p,
-                                 size_t   min_keep) {
-    const float         temp              = params.temp;
-    const float         dynatemp_range    = params.dynatemp_range;
-    const float         dynatemp_exponent = params.dynatemp_exponent;
-    const int32_t       top_k             = params.top_k;
-    const float         top_p             = params.top_p;
-    const float         min_p             = params.min_p;
-    const float         tfs_z             = params.tfs_z;
-    const float         typical_p         = params.typical_p;
-    const std::vector & samplers_sequence = params.samplers_sequence;
-
-    for (auto sampler_type : samplers_sequence) {
-        switch (sampler_type) {
-            case llama_sampler_type::TOP_K    : llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); break;
-            case llama_sampler_type::TFS_Z    : llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); break;
-            case llama_sampler_type::TYPICAL_P: llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
-            case llama_sampler_type::TOP_P    : llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break;
-            case llama_sampler_type::MIN_P    : llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break;
-            case llama_sampler_type::TEMPERATURE:
-                if (dynatemp_range > 0) {
-                    float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
-                    float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
-                    llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
-                } else {
-                    llama_sample_temp(ctx_main, &cur_p, temp);
-                }
-                break;
-            default : break;
-        }
-    }
+struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
+    return new common_sampler {
+        /* .params = */ gsmpl->params,
+        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
+        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
+        /* .prev   = */ gsmpl->prev,
+        /* .cur    = */ gsmpl->cur,
+        /* .cur_p  = */ gsmpl->cur_p,
+    };
 }
 
-static llama_token llama_sampling_sample_impl(
-                  struct llama_sampling_context * ctx_sampling,
-                  struct llama_context * ctx_main,
-                  struct llama_context * ctx_cfg,
-                  const int idx,
-                  bool is_resampling) {
-    const llama_sampling_params & params = ctx_sampling->params;
-
-    const float   temp            = params.temp;
-    const int     mirostat        = params.mirostat;
-    const float   mirostat_tau    = params.mirostat_tau;
-    const float   mirostat_eta    = params.mirostat_eta;
-
-    std::vector original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
-    if (ctx_sampling->grammar != NULL && !is_resampling) {
-        LM_GGML_ASSERT(!original_logits.empty());
-    }
-    llama_token id = 0;
-
-    if (temp < 0.0) {
-        // greedy sampling, with probs
-        llama_sample_softmax(ctx_main, &cur_p);
-        id = cur_p.data[0].id;
-    } else if (temp == 0.0) {
-        // greedy sampling, no probs
-        id = llama_sample_token_greedy(ctx_main, &cur_p);
-    } else {
-        if (mirostat == 1) {
-            const int mirostat_m = 100;
-            llama_sample_temp(ctx_main, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
-        } else if (mirostat == 2) {
-            llama_sample_temp(ctx_main, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
-        } else {
-            // temperature sampling
-            size_t min_keep = std::max(1, params.min_keep);
-
-            sampler_queue(ctx_main, params, cur_p, min_keep);
+void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
+    // TODO: measure grammar performance
 
-            id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
+    if (gsmpl) {
+        llama_perf_sampler_print(gsmpl->chain);
+    }
+    if (ctx) {
+        llama_perf_context_print(ctx);
+    }
+}
 
-            //{
-            //    const int n_top = 10;
-            //    LOG("top %d candidates:\n", n_top);
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+    gsmpl->set_logits(ctx, idx);
 
-            //    for (int i = 0; i < n_top; i++) {
-            //        const llama_token id = cur_p.data[i].id;
-            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
-            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
-            //    }
-            //}
+    auto & grmr  = gsmpl->grmr;
+    auto & chain = gsmpl->chain;
+    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
-            //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
-        }
+    if (grammar_first) {
+        llama_sampler_apply(grmr, &cur_p);
     }
 
-    if (ctx_sampling->grammar != NULL && !is_resampling) {
-        // Get a pointer to the logits
-        float * logits = llama_get_logits_ith(ctx_main, idx);
+    llama_sampler_apply(chain, &cur_p);
 
-        // Create an array with a single token data element for the sampled id
-        llama_token_data single_token_data = {id, logits[id], 0.0f};
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+    LM_GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
-        // Apply grammar constraints to the single token
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
+    const llama_token id = cur_p.data[cur_p.selected].id;
 
-        // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
-        bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+    if (grammar_first) {
+        return id;
+    }
 
-        // If the token is not valid according to the grammar, perform resampling
-        if (!is_valid) {
-            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+    // check if it the sampled token fits the grammar
+    {
+        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
 
-            // Restore logits from the copy
-            std::copy(original_logits.begin(), original_logits.end(), logits);
+        llama_sampler_apply(grmr, &single_token_data_array);
 
-            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
+        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+        if (is_valid) {
+            return id;
         }
     }
 
-    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
+    // resampling:
+    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+    gsmpl->set_logits(ctx, idx);
 
-    return id;
-}
+    llama_sampler_apply(grmr,  &cur_p);
+    llama_sampler_apply(chain, &cur_p);
 
-static llama_token_data_array llama_sampling_prepare_impl(
-                  struct llama_sampling_context * ctx_sampling,
-                  struct llama_context * ctx_main,
-                  struct llama_context * ctx_cfg,
-                  const int idx,
-                  bool apply_grammar,
-                  std::vector * original_logits) {
-    const llama_sampling_params & params = ctx_sampling->params;
+    LM_GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    return cur_p.data[cur_p.selected].id;
+}
 
-    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float   penalty_repeat  = params.penalty_repeat;
-    const float   penalty_freq    = params.penalty_freq;
-    const float   penalty_present = params.penalty_present;
+uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
+    return llama_sampler_get_seed(gsmpl->chain);
+}
 
-    const bool    penalize_nl     = params.penalize_nl;
+// helpers
 
-    auto & prev = ctx_sampling->prev;
-    auto & cur  = ctx_sampling->cur;
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
+    return &gsmpl->cur_p;
+}
 
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
+llama_token common_sampler_last(const struct common_sampler * gsmpl) {
+    return gsmpl->prev.rat(0);
+}
 
-    if (ctx_sampling->grammar != NULL && !apply_grammar) {
-        LM_GGML_ASSERT(original_logits != NULL);
-        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-        *original_logits = {logits, logits + n_vocab};
+std::string common_sampler_print(const struct common_sampler * gsmpl) {
+    std::string result = "logits ";
+
+    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
+        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
     }
 
-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
+    return result;
+}
+
+std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
+    n = std::min(n, (int) gsmpl->prev.size());
+
+    if (n <= 0) {
+        return "";
     }
 
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    std::string result;
+    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
+
+    for (int i = n - 1; i >= 0; i--) {
+        const llama_token id = gsmpl->prev.rat(i);
+
+        LM_GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
+
+        result += common_token_to_piece(ctx_main, id);
     }
 
-    cur.resize(n_vocab);
+    return result;
+}
 
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
+    switch (cnstr) {
+        case COMMON_SAMPLER_TYPE_DRY:         return 'd';
+        case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
+        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
+        case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
+        case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
+        case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
+        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
+        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
+        default : return '?';
     }
+}
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
+    switch (cnstr) {
+        case COMMON_SAMPLER_TYPE_DRY:         return "dry";
+        case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
+        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
+        case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
+        case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
+        case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
+        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
+        default : return "";
+    }
+}
 
-    // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names) {
+    std::unordered_map sampler_canonical_name_map {
+        { "dry",         COMMON_SAMPLER_TYPE_DRY },
+        { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
+        { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
+        { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
+        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
+        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
+    };
 
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+    // since samplers names are written multiple ways
+    // make it ready for both system names and input names
+    std::unordered_map sampler_alt_name_map {
+        { "top-k",       COMMON_SAMPLER_TYPE_TOP_K },
+        { "top-p",       COMMON_SAMPLER_TYPE_TOP_P },
+        { "nucleus",     COMMON_SAMPLER_TYPE_TOP_P },
+        { "typical-p",   COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { "typical",     COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
+        { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
+    };
 
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
+    std::vector samplers;
+    samplers.reserve(names.size());
+
+    for (const auto & name : names) {
+        auto sampler = sampler_canonical_name_map.find(name);
+        if (sampler != sampler_canonical_name_map.end()) {
+            samplers.push_back(sampler->second);
+        } else {
+            if (allow_alt_names) {
+                sampler = sampler_alt_name_map.find(name);
+                if (sampler != sampler_alt_name_map.end()) {
+                    samplers.push_back(sampler->second);
                 }
             }
         }
     }
 
-    // apply grammar checks before sampling logic
-    if (apply_grammar && ctx_sampling->grammar != NULL) {
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
-    }
-
-    return cur_p;
+    return samplers;
 }
 
-llama_token llama_sampling_sample(
-                  struct llama_sampling_context * ctx_sampling,
-                  struct llama_context * ctx_main,
-                  struct llama_context * ctx_cfg,
-                  const int idx) {
-    // Call the implementation function with is_resampling set to false by default
-    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
-}
-
-llama_token_data_array llama_sampling_prepare(
-                  struct llama_sampling_context * ctx_sampling,
-                  struct llama_context * ctx_main,
-                  struct llama_context * ctx_cfg,
-                  const int idx,
-                  bool apply_grammar,
-                  std::vector * original_logits) {
-    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
-}
+std::vector common_sampler_types_from_chars(const std::string & chars) {
+    std::unordered_map sampler_name_map = {
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
+    };
 
-void llama_sampling_accept(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        llama_token id,
-        bool apply_grammar) {
-    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-    ctx_sampling->prev.push_back(id);
+    std::vector samplers;
+    samplers.reserve(chars.size());
 
-    if (ctx_sampling->grammar != NULL && apply_grammar) {
-        llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
+    for (const auto & c : chars) {
+        const auto sampler = sampler_name_map.find(c);
+        if (sampler != sampler_name_map.end()) {
+            samplers.push_back(sampler->second);
+        }
     }
+
+    return samplers;
 }
diff --git a/cpp/sampling.h b/cpp/sampling.h
index eeaa53b..d37f25a 100644
--- a/cpp/sampling.h
+++ b/cpp/sampling.h
@@ -2,159 +2,82 @@
 
 #include "llama.h"
 
-#include "grammar-parser.h"
+#include "common.h"
 
-#include 
 #include 
-#include 
 #include 
 
-// sampler types
-enum class llama_sampler_type : char {
-    TOP_K       = 'k',
-    TOP_P       = 'p',
-    MIN_P       = 'm',
-    TFS_Z       = 'f',
-    TYPICAL_P   = 'y',
-    TEMPERATURE = 't'
-};
-
-// sampling parameters
-typedef struct llama_sampling_params {
-    int32_t     n_prev                = 64;                 // number of previous tokens to remember
-    int32_t     n_probs               = 0;                  // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t     min_keep              = 0;                  // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t     top_k                 = 40;                 // <= 0 to use vocab size
-    float       top_p                 = 0.95f;              // 1.0 = disabled
-    float       min_p                 = 0.05f;              // 0.0 = disabled
-    float       tfs_z                 = 1.00f;              // 1.0 = disabled
-    float       typical_p             = 1.00f;              // 1.0 = disabled
-    float       temp                  = 0.80f;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float       dynatemp_range        = 0.00f;              // 0.0 = disabled
-    float       dynatemp_exponent     = 1.00f;              // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t     penalty_last_n        = 64;                 // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float       penalty_repeat        = 1.00f;              // 1.0 = disabled
-    float       penalty_freq          = 0.00f;              // 0.0 = disabled
-    float       penalty_present       = 0.00f;              // 0.0 = disabled
-    int32_t     mirostat              = 0;                  // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float       mirostat_tau          = 5.00f;              // target entropy
-    float       mirostat_eta          = 0.10f;              // learning rate
-    bool        penalize_nl           = false;              // consider newlines as a repeatable token
-    uint32_t    seed                  = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
-
-    std::vector samplers_sequence = {
-        llama_sampler_type::TOP_K,
-        llama_sampler_type::TFS_Z,
-        llama_sampler_type::TYPICAL_P,
-        llama_sampler_type::TOP_P,
-        llama_sampler_type::MIN_P,
-        llama_sampler_type::TEMPERATURE
-    };
-
-    std::string grammar;  // optional BNF-like grammar to constrain sampling
-
-    // Classifier-Free Guidance
-    // https://arxiv.org/abs/2306.17806
-    std::string cfg_negative_prompt; // string to help guidance
-    float       cfg_scale     = 1.f; // how strong is guidance
-
-    std::unordered_map logit_bias; // logit bias for specific tokens
-
-    std::vector penalty_prompt_tokens;
-    bool                     use_penalty_prompt_tokens = false;
-} llama_sampling_params;
-
-// general sampler context
-// TODO: move to llama.h
-struct llama_sampling_context {
-    // parameters that will be used for sampling
-    llama_sampling_params params;
-
-    // mirostat sampler state
-    float mirostat_mu;
-
-    llama_grammar * grammar;
-
-    // internal
-    grammar_parser::parse_state parsed_grammar;
-
-    // TODO: replace with ring-buffer
-    std::vector      prev;
-    std::vector cur;
-    size_t n_valid; // Number of correct top tokens with correct probabilities.
-
-    std::mt19937 rng;
-};
+// common_sampler extends llama_sampler with additional functionality:
+//
+//  - grammar support
+//  - custom sampler logic based on the parameters
+//  - history of the last accepted tokens
+//  - performance metrics
+//
+// This goal is to have a common implementation of the sampling logic shared across the examples.
+// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
+// complex (top-k, top-p, etc).
+//
+// Another example is related to the grammar. In general, the grammar constraints applied on the full
+// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
+// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
+// grammar constraints are applied to the full vocabulary and the token is resampled.
+//
+// The common_sampler also maintains a container with the last accepted tokens. In the future, this can
+// be moved into the core llama library.
+//
+// For convenience, the common_sampler also maintains a container with the current candidate tokens.
+// This can be used to access the probabilities of the rest of the non-sampled tokens.
+//
+// TODO: measure grammar performance
+//
 
-#include "common.h"
+struct common_sampler;
 
-// Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
+// llama_sampler API overloads
 
-void llama_sampling_free(struct llama_sampling_context * ctx);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
 
-// Reset the sampler context
-// - clear prev tokens
-// - reset grammar
-void llama_sampling_reset(llama_sampling_context * ctx);
+void common_sampler_free(struct common_sampler * gsmpl);
 
-// Set the sampler seed
-void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
+// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
+void                    common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
+void                    common_sampler_reset (struct common_sampler * gsmpl);
+struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
 
-// Copy the sampler context
-void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
+// arguments can be nullptr to skip printing
+void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
 
-// Get the last sampled token
-llama_token llama_sampling_last(llama_sampling_context * ctx);
+// extended sampling implementation:
+//
+// - set logits
+// - apply the configured sampler chain
+// - check if the token fits the grammar (if any)
+// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
+//
+// if grammar_first is true, the grammar is applied before the samplers (slower)
+// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
+//
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
-// Get a string representation of the last sampled tokens
-std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
+uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
-// Print sampling parameters into a string
-std::string llama_sampling_print(const llama_sampling_params & params);
+// helpers
 
-// Print sampling order into a string
-std::string llama_sampling_order_print(const llama_sampling_params & params);
+// access the internal list of current candidate tokens
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
 
-std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
+// get the last accepted token
+llama_token common_sampler_last(const struct common_sampler * gsmpl);
 
-std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names);
-std::vector llama_sampling_types_from_chars(const std::string & names_string);
+// print the sampler chain into a string
+std::string common_sampler_print(const struct common_sampler * gsmpl);
 
-// this is a common sampling function used across the examples for convenience
-// it can serve as a starting point for implementing your own sampling function
-// Note: When using multiple sequences, it is the caller's responsibility to call
-//       llama_sampling_reset when a sequence ends
-//
-// required:
-//  - ctx_main:     context to use for sampling
-//  - ctx_sampling: sampling-specific context
-//
-// optional:
-//  - ctx_cfg:      context to use for classifier-free guidance
-//  - idx:          sample from llama_get_logits_ith(ctx, idx)
-//
-// returns:
-//  - token:      sampled token
-//  - candidates: vector of candidate tokens
-//
-llama_token llama_sampling_sample(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        int idx = -1);
-
-// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
-llama_token_data_array llama_sampling_prepare(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        int idx = 0,
-        bool apply_grammar = true,
-        std::vector * original_logits = nullptr);
-
-void llama_sampling_accept(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        llama_token id,
-        bool apply_grammar);
+// get a string representation of the last accepted tokens
+std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
+
+char        common_sampler_type_to_chr(enum common_sampler_type cnstr);
+std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
+
+std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names);
+std::vector common_sampler_types_from_chars(const std::string & chars);
diff --git a/cpp/sgemm.cpp b/cpp/sgemm.cpp
index 0205fd9..41dbb45 100644
--- a/cpp/sgemm.cpp
+++ b/cpp/sgemm.cpp
@@ -50,6 +50,7 @@
 
 #include "sgemm.h"
 #include "ggml-impl.h"
+#include "ggml-cpu-impl.h"
 #include "ggml-quants.h"
 
 #ifdef _MSC_VER
@@ -235,6 +236,14 @@ template <> inline __m512 load(const lm_ggml_fp16_t *p) {
 }
 #endif // __AVX512F__
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// CONSTANTS
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
@@ -606,17 +615,29 @@ class tinyBLAS_Q0_AVX {
         case 0x44:
             mc = 4;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<4>(m0, m, n0, n);
+#else
             gemm<4, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x43:
             mc = 4;
             nc = 3;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<3>(m0, m, n0, n);
+#else
             gemm<4, 3>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
             mc = 3;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<3>(m0, m, n0, n);
+#else
             gemm<3, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
             mc = 3;
@@ -626,12 +647,20 @@ class tinyBLAS_Q0_AVX {
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
 #else
         case 0x44:
@@ -639,13 +668,21 @@ class tinyBLAS_Q0_AVX {
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
 #endif
@@ -662,7 +699,11 @@ class tinyBLAS_Q0_AVX {
         case 0x41:
             mc = 4;
             nc = 1;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<1>(m0, m, n0, n);
+#else
             gemm<4, 1>(m0, m, n0, n);
+#endif
             break;
         case 0x22:
             mc = 2;
@@ -672,7 +713,11 @@ class tinyBLAS_Q0_AVX {
         case 0x14:
             mc = 1;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<1>(m0, m, n0, n);
+#else
             gemm<1, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x31:
             mc = 3;
@@ -708,6 +753,119 @@ class tinyBLAS_Q0_AVX {
         mnpack(m0, m, np, n);
     }
 
+#if defined(__AVX2__) && defined(__F16C__)
+// Templated functions for gemm of dimensions 4xN
+    template 
+    NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / 4;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * 4;
+            int64_t jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][4] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
+                // Convert delta values for four blocks to float values
+                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
+                __m256i avec0 = load(A + lda * (ii + 0) + l);
+                __m256i avec1 = load(A + lda * (ii + 1) + l);
+                __m256i avec2 = load(A + lda * (ii + 2) + l);
+                __m256i avec3 = load(A + lda * (ii + 3) + l);
+                for (int64_t j = 0; j < RN; ++j) {
+                        __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
+                        // Computation of product of delta values for four blocks and replicate it across 256 bit lane
+                        __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                        dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
+                        // Computation of dot product and multiplication with appropriate delta value products
+                        Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(avec0, avec0),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
+                                    Cv[j][0]);
+                        Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(avec1, avec1),
+                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
+                                    Cv[j][1]);
+                        Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(avec2, avec2),
+                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
+                                    Cv[j][2]);
+                        Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(avec3, avec3),
+                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
+                                    Cv[j][3]);
+                }
+            }
+
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < 4; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+
+    // Templated functions for gemm of dimensions Mx4
+    template 
+    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / 4;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * 4;
+            __m256 Cv[4][RM] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
+                // Convert delta values for four blocks to float values
+                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
+                __m256i bvec0 = load(B + ldb * (jj + 0) + l);
+                __m256i bvec1 = load(B + ldb * (jj + 1) + l);
+                __m256i bvec2 = load(B + ldb * (jj + 2) + l);
+                __m256i bvec3 = load(B + ldb * (jj + 3) + l);
+                for (int64_t i = 0; i < RM; ++i) {
+                    __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
+                    // Computation of product of delta values for four blocks and replicate it across 256 bit lane
+                    __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                    dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
+                    // Computation of dot product and multiplication with appropriate delta value products
+                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                            load(A + lda * (ii + i) + l)),
+                                            _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
+                                    Cv[0][i]);
+                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                            load(A + lda * (ii + i) + l)),
+                                            _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
+                                    Cv[1][i]);
+                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                            load(A + lda * (ii + i) + l)),
+                                            _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
+                                    Cv[2][i]);
+                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                            load(A + lda * (ii + i) + l)),
+                                            _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
+                                    Cv[3][i]);
+                }
+            }
+            for (int64_t j = 0; j < 4; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+#endif
+
     template 
     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t ytiles = (m - m0) / RM;
@@ -784,6 +942,50 @@ class tinyBLAS_Q0_AVX {
         return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
     }
 
+    inline __m256i load(const block_q5_0 *b) {
+        return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
+    }
+
+    inline __m128i load0(const block_q5_0* b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        uint32_t x32;
+        memcpy(&x32, b->qh, sizeof(uint32_t));
+        __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
+        __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
+                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
+                                                                      _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
+        bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
+        return _mm_or_si128(qxl, bytesl);
+    }
+
+    inline __m128i load1(const block_q5_0* b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        uint32_t x32;
+        memcpy(&x32, b->qh, sizeof(uint32_t));
+        __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
+        __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
+                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
+                                                                      _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
+        bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
+        return _mm_or_si128(qxh, bytesh);
+    }
+
+    inline __m256i load(const block_iq4_nl *b) {
+        return MM256_SET_M128I(load1(b), load0(b));
+    }
+
+    inline __m128i load0(const block_iq4_nl *b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
+    }
+
+    inline __m128i load1(const block_iq4_nl *b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
+    }
+
     inline __m256 updot(__m256i u, __m256i s) {
         __m256i res;
 #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@@ -801,6 +1003,17 @@ class tinyBLAS_Q0_AVX {
                                                         _mm_srli_epi16(x, 4), 1));
     }
 
+    static inline __m256i bittobyte(const uint8_t *p) {
+        uint32_t x32;
+        memcpy(&x32, p, sizeof(uint32_t));
+        __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
+                                          _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                          _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
+                                                                              _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
+                                                                                                0x0101010101010101, 0x0000000000000000))));
+        return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
+    }
+
     const TA *const A;
     const TB *const B;
     TC *const C;
@@ -857,6 +1070,10 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     assert(nth > 0);
     assert(ith < nth);
 
+    // only enable sgemm for prompt processing
+    if (n < 2)
+        return false;
+
     if (Ctype != LM_GGML_TYPE_F32)
         return false;
 
@@ -1006,6 +1223,38 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
 #endif
     }
 
+    case LM_GGML_TYPE_Q5_0: {
+        if (Btype != LM_GGML_TYPE_Q8_0)
+            return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX tb{
+            k, (const block_q5_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    case LM_GGML_TYPE_IQ4_NL: {
+        if (Btype != LM_GGML_TYPE_Q8_0)
+            return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX tb{
+            k, (const block_iq4_nl *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
     default:
         return false;
     }
diff --git a/cpp/unicode-data.cpp b/cpp/unicode-data.cpp
index 02bdf78..04dcd7f 100644
--- a/cpp/unicode-data.cpp
+++ b/cpp/unicode-data.cpp
@@ -7,7 +7,7 @@
 #include 
 #include 
 
-const std::vector> unicode_ranges_flags = {  // start, flags // last=next_start-1
+const std::initializer_list> unicode_ranges_flags = {  // start, flags // last=next_start-1
 {0x000000, 0x0080},
 {0x000020, 0x0008},
 {0x000021, 0x0020},
@@ -2311,7 +2311,8 @@ const std::unordered_set unicode_set_whitespace = {
 0x003000,
 };
 
-const std::unordered_map unicode_map_lowercase = {
+// list is always in ascending order, to enable binary search
+const std::initializer_list> unicode_map_lowercase = {
 {0x000041, 0x000061},
 {0x000042, 0x000062},
 {0x000043, 0x000063},
@@ -3747,7 +3748,8 @@ const std::unordered_map unicode_map_lowercase = {
 {0x01E921, 0x01E943},
 };
 
-const std::unordered_map unicode_map_uppercase = {
+// list is always in ascending order, to enable binary search
+const std::initializer_list> unicode_map_uppercase = {
 {0x000061, 0x000041},
 {0x000062, 0x000042},
 {0x000063, 0x000043},
@@ -5200,7 +5202,7 @@ const std::unordered_map unicode_map_uppercase = {
 {0x01E943, 0x01E921},
 };
 
-const std::vector unicode_ranges_nfd = {  // start, last, nfd
+const std::initializer_list unicode_ranges_nfd = {  // start, last, nfd
 {0x000000, 0x000000, 0x000000},
 {0x0000C0, 0x0000C5, 0x000041},
 {0x0000C7, 0x0000C7, 0x000043},
diff --git a/cpp/unicode-data.h b/cpp/unicode-data.h
index e27fe17..f6973eb 100644
--- a/cpp/unicode-data.h
+++ b/cpp/unicode-data.h
@@ -13,8 +13,8 @@ struct range_nfd {
 
 static const uint32_t MAX_CODEPOINTS = 0x110000;
 
-extern const std::vector> unicode_ranges_flags;
+extern const std::initializer_list> unicode_ranges_flags;
 extern const std::unordered_set unicode_set_whitespace;
-extern const std::unordered_map unicode_map_lowercase;
-extern const std::unordered_map unicode_map_uppercase;
-extern const std::vector unicode_ranges_nfd;
+extern const std::initializer_list> unicode_map_lowercase;
+extern const std::initializer_list> unicode_map_uppercase;
+extern const std::initializer_list unicode_ranges_nfd;
diff --git a/cpp/unicode.cpp b/cpp/unicode.cpp
index 46650bf..50b35bb 100644
--- a/cpp/unicode.cpp
+++ b/cpp/unicode.cpp
@@ -5,6 +5,7 @@
 #include "unicode.h"
 #include "unicode-data.h"
 
+#include 
 #include 
 #include 
 #include 
@@ -122,11 +123,11 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
 static std::vector unicode_cpt_flags_array() {
     std::vector cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
 
-    assert (unicode_ranges_flags.front().first == 0);
-    assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
+    assert (unicode_ranges_flags.begin()[0].first == 0);
+    assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
     for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
-        const auto range_ini = unicode_ranges_flags[i-1];  // codepoint_ini, flags
-        const auto range_end = unicode_ranges_flags[i];    // codepoint_end, flags
+        const auto range_ini = unicode_ranges_flags.begin()[i-1];  // codepoint_ini, flags
+        const auto range_end = unicode_ranges_flags.begin()[i];    // codepoint_end, flags
         for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
             cpt_flags[cpt] = range_ini.second;
         }
@@ -596,7 +597,7 @@ std::vector unicode_cpts_normalize_nfd(const std::vector & c
     std::vector result(cpts.size());
     for (size_t i = 0; i < cpts.size(); ++i) {
         const uint32_t cpt = cpts[i];
-        auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
+        auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1;
         result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
     }
     return result;
@@ -638,8 +639,15 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) {
 }
 
 uint32_t unicode_tolower(uint32_t cp) {
-    auto it = unicode_map_lowercase.find(cp);
-    return it == unicode_map_lowercase.end() ? cp : it->second;
+    // binary search
+    auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cp,
+        [](const std::pair & pair, uint32_t value) {
+            return pair.first < value;
+        });
+    if (it != unicode_map_lowercase.end() && it->first == cp) {
+        return it->second;
+    }
+    return cp;  // Return the original code point if no lowercase mapping is found
 }
 
 std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) {
diff --git a/example/ios/.xcode.env.local b/example/ios/.xcode.env.local
index 51de392..289c3d0 100644
--- a/example/ios/.xcode.env.local
+++ b/example/ios/.xcode.env.local
@@ -1 +1 @@
-export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1722073570606-0.6759511337227031/node
+export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1730514789911-0.16979892623603998/node
diff --git a/example/src/App.tsx b/example/src/App.tsx
index 069511a..7f5959a 100644
--- a/example/src/App.tsx
+++ b/example/src/App.tsx
@@ -268,7 +268,6 @@ export default function App() {
     ]
     addMessage(textMessage)
     setInferencing(true)
-
     // Test area
     {
       // Test tokenize
@@ -348,11 +347,12 @@ export default function App() {
       ?.completion(
         {
           messages: msgs,
-          n_predict: 400,
+          n_predict: 100,
+          xtc_probability: 0.5,
+          xtc_threshold: 0.1,
           temperature: 0.7,
           top_k: 40, // <= 0 to use vocab size
           top_p: 0.5, // 1.0 = disabled
-          tfs_z: 1.0, // 1.0 = disabled
           typical_p: 1.0, // 1.0 = disabled
           penalty_last_n: 256, // 0 = disable penalty, -1 = context size
           penalty_repeat: 1.18, // 1.0 = disabled
diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm
index f7fbc95..b2f959d 100644
--- a/ios/RNLlamaContext.mm
+++ b/ios/RNLlamaContext.mm
@@ -5,7 +5,12 @@ @implementation RNLlamaContext
 
 + (instancetype)initWithParams:(NSDictionary *)params {
     // llama_backend_init(false);
-    gpt_params defaultParams;
+    common_params defaultParams;
+
+    if (params[@"vocab_only"]) {
+        defaultParams.vocab_only = [params[@"vocab_only"] boolValue];
+        defaultParams.warmup = false;
+    }
 
     NSString *modelPath = params[@"model"];
     BOOL isAsset = [params[@"is_model_asset"] boolValue];
@@ -66,13 +71,11 @@ + (instancetype)initWithParams:(NSDictionary *)params {
     if (params[@"rope_freq_base"]) defaultParams.rope_freq_base = [params[@"rope_freq_base"] floatValue];
     if (params[@"rope_freq_scale"]) defaultParams.rope_freq_scale = [params[@"rope_freq_scale"] floatValue];
 
-    if (params[@"seed"]) defaultParams.seed = [params[@"seed"] intValue];
-
     int nThreads = params[@"n_threads"] ? [params[@"n_threads"] intValue] : 0;
     const int maxThreads = (int) [[NSProcessInfo processInfo] processorCount];
     // Use 2 threads by default on 4-core devices, 4 threads on more cores
     const int defaultNThreads = nThreads == 4 ? 2 : MIN(4, maxThreads);
-    defaultParams.n_threads = nThreads > 0 ? nThreads : defaultNThreads;
+    defaultParams.cpuparams.n_threads = nThreads > 0 ? nThreads : defaultNThreads;
 
     RNLlamaContext *context = [[RNLlamaContext alloc] init];
     if (context->llama == nullptr) {
@@ -128,7 +131,7 @@ - (bool)isPredicting {
 }
 
 - (NSString *)getFormattedChat:(NSArray *)messages withTemplate:(NSString *)chatTemplate {
-  std::vector chat;
+  std::vector chat;
 
   for (NSDictionary *msg in messages) {
     std::string role = [[msg objectForKey:@"role"] UTF8String];
@@ -137,7 +140,7 @@ - (NSString *)getFormattedChat:(NSArray *)messages withTemplate:(NSString *)chat
   }
 
   auto tmpl = chatTemplate == nil ? "" : [chatTemplate UTF8String];
-  auto formatted_chat = llama_chat_apply_template(llama->model, tmpl, chat, true);
+  auto formatted_chat = common_chat_apply_template(llama->model, tmpl, chat, true);
   return [NSString stringWithUTF8String:formatted_chat.c_str()];
 }
 
@@ -168,21 +171,22 @@ - (NSDictionary *)completion:(NSDictionary *)params
 {
     llama->rewind();
 
-    llama_reset_timings(llama->ctx);
+    //llama_reset_timings(llama->ctx);
 
     NSString *prompt = [params objectForKey:@"prompt"];
 
     llama->params.prompt = [prompt UTF8String];
-    llama->params.seed = params[@"seed"] ? [params[@"seed"] intValue] : -1;
+    llama->params.sparams.seed = params[@"seed"] ? [params[@"seed"] intValue] : -1;
 
     if (params[@"n_threads"]) {
-        int nThreads = params[@"n_threads"] ? [params[@"n_threads"] intValue] : llama->params.n_threads;
+        int nThreads = params[@"n_threads"] ? [params[@"n_threads"] intValue] : llama->params.cpuparams.n_threads;
         const int maxThreads = (int) [[NSProcessInfo processInfo] processorCount];
         // Use 2 threads by default on 4-core devices, 4 threads on more cores
         const int defaultNThreads = nThreads == 4 ? 2 : MIN(4, maxThreads);
-        llama->params.n_threads = nThreads > 0 ? nThreads : defaultNThreads;
+        llama->params.cpuparams.n_threads = nThreads > 0 ? nThreads : defaultNThreads;
     }
     if (params[@"n_predict"]) llama->params.n_predict = [params[@"n_predict"] intValue];
+    if (params[@"ignore_eos"]) llama->params.sparams.ignore_eos = [params[@"ignore_eos"] boolValue];
 
     auto & sparams = llama->params.sparams;
 
@@ -203,9 +207,9 @@ - (NSDictionary *)completion:(NSDictionary *)params
     if (params[@"top_k"]) sparams.top_k = [params[@"top_k"] intValue];
     if (params[@"top_p"]) sparams.top_p = [params[@"top_p"] doubleValue];
     if (params[@"min_p"]) sparams.min_p = [params[@"min_p"] doubleValue];
-    if (params[@"tfs_z"]) sparams.tfs_z = [params[@"tfs_z"] doubleValue];
-
-    if (params[@"typical_p"]) sparams.typical_p = [params[@"typical_p"] doubleValue];
+    if (params[@"xtc_threshold"]) sparams.xtc_threshold = [params[@"xtc_threshold"] doubleValue];
+    if (params[@"xtc_probability"]) sparams.xtc_probability = [params[@"xtc_probability"] doubleValue];
+    if (params[@"typical_p"]) sparams.typ_p = [params[@"typical_p"] doubleValue];
 
     if (params[@"grammar"]) {
         sparams.grammar = [params[@"grammar"] UTF8String];
@@ -221,7 +225,7 @@ - (NSDictionary *)completion:(NSDictionary *)params
 
     sparams.logit_bias.clear();
     if (params[@"ignore_eos"] && [params[@"ignore_eos"] boolValue]) {
-        sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
+        sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
     }
 
     if (params[@"logit_bias"] && [params[@"logit_bias"] isKindOfClass:[NSArray class]]) {
@@ -232,9 +236,9 @@ - (NSDictionary *)completion:(NSDictionary *)params
                 llama_token tok = [el[0] intValue];
                 if (tok >= 0 && tok < n_vocab) {
                     if ([el[1] isKindOfClass:[NSNumber class]]) {
-                        sparams.logit_bias[tok] = [el[1] doubleValue];
+                        sparams.logit_bias[tok].bias = [el[1] doubleValue];
                     } else if ([el[1] isKindOfClass:[NSNumber class]] && ![el[1] boolValue]) {
-                        sparams.logit_bias[tok] = -INFINITY;
+                        sparams.logit_bias[tok].bias = -INFINITY;
                     }
                 }
             }
@@ -255,7 +259,7 @@ - (NSDictionary *)completion:(NSDictionary *)params
         if (token_with_probs.tok == -1 || llama->incomplete) {
             continue;
         }
-        const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
+        const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);
 
         size_t pos = std::min(sent_count, llama->generated_text.size());
 
@@ -290,7 +294,7 @@ - (NSDictionary *)completion:(NSDictionary *)params
             tokenResult[@"token"] = [NSString stringWithUTF8String:to_send.c_str()];
 
             if (llama->params.sparams.n_probs > 0) {
-                const std::vector to_send_toks = llama_tokenize(llama->ctx, to_send, false);
+                const std::vector to_send_toks = common_tokenize(llama->ctx, to_send, false);
                 size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
                 size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
                 if (probs_pos < probs_stop_pos) {
@@ -305,10 +309,10 @@ - (NSDictionary *)completion:(NSDictionary *)params
         }
     }
 
-    llama_print_timings(llama->ctx);
+    llama_perf_context_print(llama->ctx);
     llama->is_predicting = false;
 
-    const auto timings = llama_get_timings(llama->ctx);
+    const auto timings = llama_perf_context(llama->ctx);
     return @{
         @"text": [NSString stringWithUTF8String:llama->generated_text.c_str()],
         @"completion_probabilities": [self tokenProbsToDict:llama->generated_token_probs],
@@ -339,7 +343,7 @@ - (void)stopCompletion {
 }
 
 - (NSArray *)tokenize:(NSString *)text {
-    const std::vector toks = llama_tokenize(llama->ctx, [text UTF8String], false);
+    const std::vector toks = common_tokenize(llama->ctx, [text UTF8String], false);
     NSMutableArray *result = [[NSMutableArray alloc] init];
     for (llama_token tok : toks) {
         [result addObject:@(tok)];
@@ -363,7 +367,7 @@ - (NSArray *)embedding:(NSString *)text {
 
     llama->rewind();
 
-    llama_reset_timings(llama->ctx);
+    llama_perf_context_reset(llama->ctx);
 
     llama->params.prompt = [text UTF8String];
 
diff --git a/llama.cpp b/llama.cpp
index 93bc383..a6744e4 160000
--- a/llama.cpp
+++ b/llama.cpp
@@ -1 +1 @@
-Subproject commit 93bc3839f980ff14be86efe408b4cd7e89b26835
+Subproject commit a6744e43e80f4be6398fc7733a01642c846dce1d
diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh
index 195427e..92521e9 100755
--- a/scripts/bootstrap.sh
+++ b/scripts/bootstrap.sh
@@ -3,41 +3,49 @@
 git submodule init
 git submodule update --recursive
 
+cp ./llama.cpp/include/llama.h ./cpp/llama.h
+
 cp ./llama.cpp/ggml/include/ggml.h ./cpp/ggml.h
-cp ./llama.cpp/ggml/src/ggml.c ./cpp/ggml.c
+cp ./llama.cpp/ggml/include/ggml-cpp.h ./cpp/ggml-cpp.h
+cp ./llama.cpp/ggml/include/ggml-alloc.h ./cpp/ggml-alloc.h
+cp ./llama.cpp/ggml/include/ggml-backend.h ./cpp/ggml-backend.h
 cp ./llama.cpp/ggml/include/ggml-metal.h ./cpp/ggml-metal.h
+
+cp ./llama.cpp/ggml/src/ggml.c ./cpp/ggml.c
 cp ./llama.cpp/ggml/src/ggml-metal.m ./cpp/ggml-metal.m
-cp ./llama.cpp/ggml/include/ggml-alloc.h ./cpp/ggml-alloc.h
 cp ./llama.cpp/ggml/src/ggml-alloc.c ./cpp/ggml-alloc.c
-cp ./llama.cpp/ggml/include/ggml-backend.h ./cpp/ggml-backend.h
-cp ./llama.cpp/ggml/src/ggml-backend.c ./cpp/ggml-backend.c
+cp ./llama.cpp/ggml/src/ggml-backend.cpp ./cpp/ggml-backend.cpp
 cp ./llama.cpp/ggml/src/ggml-backend-impl.h ./cpp/ggml-backend-impl.h
 cp ./llama.cpp/ggml/src/ggml-impl.h ./cpp/ggml-impl.h
+cp ./llama.cpp/ggml/src/ggml-cpu-impl.h ./cpp/ggml-cpu-impl.h
 cp ./llama.cpp/ggml/src/ggml-common.h ./cpp/ggml-common.h
 cp ./llama.cpp/ggml/src/ggml-quants.h ./cpp/ggml-quants.h
 cp ./llama.cpp/ggml/src/ggml-quants.c ./cpp/ggml-quants.c
+cp ./llama.cpp/ggml/src/ggml-aarch64.c ./cpp/ggml-aarch64.c
+cp ./llama.cpp/ggml/src/ggml-aarch64.h ./cpp/ggml-aarch64.h
+
 cp ./llama.cpp/ggml/src/llamafile/sgemm.h ./cpp/sgemm.h
 cp ./llama.cpp/ggml/src/llamafile/sgemm.cpp ./cpp/sgemm.cpp
-cp ./llama.cpp/ggml/src/ggml-aarch64.h ./cpp/ggml-aarch64.h
-cp ./llama.cpp/ggml/src/ggml-aarch64.c ./cpp/ggml-aarch64.c
-cp ./llama.cpp/include/llama.h ./cpp/llama.h
+
 cp ./llama.cpp/src/llama.cpp ./cpp/llama.cpp
-cp ./llama.cpp/src/llama-vocab.cpp ./cpp/llama-vocab.cpp
+cp ./llama.cpp/src/llama-impl.h ./cpp/llama-impl.h
+
 cp ./llama.cpp/src/llama-vocab.h ./cpp/llama-vocab.h
-cp ./llama.cpp/src/llama-sampling.cpp ./cpp/llama-sampling.cpp
-cp ./llama.cpp/src/llama-sampling.h ./cpp/llama-sampling.h
-cp ./llama.cpp/src/llama-grammar.cpp ./cpp/llama-grammar.cpp
+cp ./llama.cpp/src/llama-vocab.cpp ./cpp/llama-vocab.cpp
 cp ./llama.cpp/src/llama-grammar.h ./cpp/llama-grammar.h
-cp ./llama.cpp/src/llama-impl.h ./cpp/llama-impl.h
+cp ./llama.cpp/src/llama-grammar.cpp ./cpp/llama-grammar.cpp
+cp ./llama.cpp/src/llama-sampling.h ./cpp/llama-sampling.h
+cp ./llama.cpp/src/llama-sampling.cpp ./cpp/llama-sampling.cpp
+
 cp ./llama.cpp/src/unicode.h ./cpp/unicode.h
 cp ./llama.cpp/src/unicode.cpp ./cpp/unicode.cpp
 cp ./llama.cpp/src/unicode-data.h ./cpp/unicode-data.h
 cp ./llama.cpp/src/unicode-data.cpp ./cpp/unicode-data.cpp
+
 cp ./llama.cpp/common/log.h ./cpp/log.h
+cp ./llama.cpp/common/log.cpp ./cpp/log.cpp
 cp ./llama.cpp/common/common.h ./cpp/common.h
 cp ./llama.cpp/common/common.cpp ./cpp/common.cpp
-cp ./llama.cpp/common/grammar-parser.h ./cpp/grammar-parser.h
-cp ./llama.cpp/common/grammar-parser.cpp ./cpp/grammar-parser.cpp
 cp ./llama.cpp/common/json.hpp ./cpp/json.hpp
 cp ./llama.cpp/common/json-schema-to-grammar.h ./cpp/json-schema-to-grammar.h
 cp ./llama.cpp/common/json-schema-to-grammar.cpp ./cpp/json-schema-to-grammar.cpp
@@ -46,28 +54,36 @@ cp ./llama.cpp/common/sampling.cpp ./cpp/sampling.cpp
 
 # List of files to process
 files=(
+  "./cpp/llama-impl.h"
+  "./cpp/llama-vocab.h"
+  "./cpp/llama-vocab.cpp"
+  "./cpp/llama-grammar.h"
+  "./cpp/llama-grammar.cpp"
+  "./cpp/llama-sampling.h"
+  "./cpp/llama-sampling.cpp"
+  "./cpp/log.h"
+  "./cpp/log.cpp"
   "./cpp/ggml.h"
   "./cpp/ggml.c"
   "./cpp/common.h"
   "./cpp/common.cpp"
+  "./cpp/ggml-cpp.h"
   "./cpp/ggml-metal.h"
   "./cpp/ggml-metal.m"
   "./cpp/llama.h"
   "./cpp/llama.cpp"
-  "./cpp/llama-vocab.cpp"
-  "./cpp/llama-sampling.cpp"
-  "./cpp/llama-grammar.cpp"
-  "./cpp/llama-impl.h"
   "./cpp/sampling.cpp"
   "./cpp/ggml-quants.h"
   "./cpp/ggml-quants.c"
   "./cpp/ggml-alloc.h"
   "./cpp/ggml-alloc.c"
   "./cpp/ggml-backend.h"
-  "./cpp/ggml-backend.c"
+  "./cpp/ggml-backend.cpp"
   "./cpp/ggml-backend-impl.h"
   "./cpp/ggml-impl.h"
+  "./cpp/ggml-cpu-impl.h"
   "./cpp/ggml-common.h"
+  "./cpp/sgemm.h"
   "./cpp/sgemm.cpp"
   "./cpp/json-schema-to-grammar.h"
   "./cpp/ggml-aarch64.h"
@@ -100,7 +116,7 @@ yarn example
 # Apply patch
 patch -p0 -d ./cpp < ./scripts/common.h.patch
 patch -p0 -d ./cpp < ./scripts/common.cpp.patch
-patch -p0 -d ./cpp < ./scripts/log.h.patch
+patch -p0 -d ./cpp < ./scripts/log.cpp.patch
 patch -p0 -d ./cpp < ./scripts/llama.cpp.patch
 patch -p0 -d ./cpp < ./scripts/ggml-metal.m.patch
 patch -p0 -d ./cpp < ./scripts/ggml.c.patch
diff --git a/scripts/common.cpp.patch b/scripts/common.cpp.patch
index 3c793cd..ae22c09 100644
--- a/scripts/common.cpp.patch
+++ b/scripts/common.cpp.patch
@@ -1,15 +1,24 @@
---- common.cpp.orig
-+++ common.cpp
-@@ -52,6 +52,12 @@
+--- common.cpp.orig	2024-11-02 10:33:10
++++ common.cpp	2024-11-02 10:33:11
+@@ -53,6 +53,12 @@
+ #include 
  #include 
  #endif
-
++
 +// build info
 +int LLAMA_BUILD_NUMBER = 0;
 +char const *LLAMA_COMMIT = "unknown";
 +char const *LLAMA_COMPILER = "unknown";
 +char const *LLAMA_BUILD_TARGET = "unknown";
-+
+
  #if defined(_MSC_VER)
  #pragma warning(disable: 4244 4267) // possible loss of data
- #endif
+@@ -979,6 +985,8 @@
+     if (params.n_gpu_layers != -1) {
+         mparams.n_gpu_layers = params.n_gpu_layers;
+     }
++
++    mparams.vocab_only      = params.vocab_only;
+     mparams.rpc_servers     = params.rpc_servers.c_str();
+     mparams.main_gpu        = params.main_gpu;
+     mparams.split_mode      = params.split_mode;
diff --git a/scripts/common.h.patch b/scripts/common.h.patch
index cd020e7..354d31d 100644
--- a/scripts/common.h.patch
+++ b/scripts/common.h.patch
@@ -1,9 +1,10 @@
---- common.h.orig	2024-08-23 14:12:33
-+++ common.h	2024-08-23 14:12:34
-@@ -50,6 +50,17 @@
-
- struct llama_control_vector_load_info;
+--- common.h.orig	2024-11-02 10:33:10
++++ common.h	2024-11-02 10:33:11
+@@ -40,6 +40,17 @@
+ extern char const * LLAMA_BUILD_TARGET;
 
+ struct common_control_vector_load_info;
++
 +#define print_build_info() do {                                                                     \
 +    fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);           \
 +    fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);    \
@@ -14,7 +15,14 @@
 +extern char const *LLAMA_COMMIT;
 +extern char const *LLAMA_COMPILER;
 +extern char const *LLAMA_BUILD_TARGET;
-+
+
  //
  // CPU utils
- //
+@@ -154,6 +165,7 @@
+ };
+
+ struct common_params {
++    bool vocab_only               = false;
+     int32_t n_predict             =    -1; // new tokens to predict
+     int32_t n_ctx                 =     0; // context size
+     int32_t n_batch               =  2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
diff --git a/scripts/ggml-metal.m.patch b/scripts/ggml-metal.m.patch
index a1266ac..4f3821a 100644
--- a/scripts/ggml-metal.m.patch
+++ b/scripts/ggml-metal.m.patch
@@ -1,9 +1,9 @@
---- ggml-metal.m.orig	2024-08-23 14:12:33
-+++ ggml-metal.m	2024-08-23 14:12:34
-@@ -340,7 +340,7 @@
+--- ggml-metal.m.orig
++++ ggml-metal.m
+@@ -389,7 +389,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
          const bool try_metallib = true;
  #endif
-
+ 
 -        NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
 +        NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
          if (try_metallib && path_lib != nil) {
diff --git a/scripts/ggml.c.patch b/scripts/ggml.c.patch
index 984f9a6..25744ff 100644
--- a/scripts/ggml.c.patch
+++ b/scripts/ggml.c.patch
@@ -1,6 +1,6 @@
---- ggml.c.orig	2024-08-23 14:12:33
-+++ ggml.c	2024-08-23 14:12:34
-@@ -194,9 +194,9 @@
+--- ggml.c.orig
++++ ggml.c
+@@ -242,9 +242,9 @@ static void lm_ggml_print_backtrace_symbols(void) {
  #elif defined(__linux__) && defined(__GLIBC__)
  #include 
  static void lm_ggml_print_backtrace_symbols(void) {
diff --git a/scripts/llama.cpp.patch b/scripts/llama.cpp.patch
index f5d26b8..b4c1571 100644
--- a/scripts/llama.cpp.patch
+++ b/scripts/llama.cpp.patch
@@ -1,6 +1,6 @@
---- llama.cpp.orig	2024-08-23 14:12:33
-+++ llama.cpp	2024-08-23 14:12:34
-@@ -104,6 +104,17 @@
+--- llama.cpp.orig	2024-11-02 11:13:58
++++ llama.cpp	2024-11-02 11:19:21
+@@ -80,6 +80,17 @@
  #define LLAMA_MAX_LAYERS  512
  #define LLAMA_MAX_EXPERTS 160  // DeepSeekV2
 
@@ -18,7 +18,7 @@
  //
  // helpers
  //
-@@ -1741,16 +1752,16 @@
+@@ -1930,16 +1941,16 @@
 
          if (prefetch > 0) {
              // advise the kernel to preload the mapped memory
@@ -39,3 +39,28 @@
                          strerror(errno));
              }
          }
+@@ -19086,7 +19097,9 @@
+
+ #ifdef GGML_USE_METAL
+     // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
+-    result.n_gpu_layers = 999;
++    if (result.n_gpu_layers > 0) {
++        result.n_gpu_layers = 999;
++    }
+ #endif
+
+     return result;
+@@ -19289,7 +19302,13 @@
+                 break;
+
+             case LM_GGML_BACKEND_DEVICE_TYPE_GPU:
++#ifdef LM_GGML_USE_METAL
++                if (params.n_gpu_layers > 0) {
++                    model->devices.push_back(dev);
++                }
++#else
+                 model->devices.push_back(dev);
++#endif
+                 break;
+         }
+     }
diff --git a/scripts/log.cpp.patch b/scripts/log.cpp.patch
new file mode 100644
index 0000000..3fcfcce
--- /dev/null
+++ b/scripts/log.cpp.patch
@@ -0,0 +1,54 @@
+--- log.cpp.orig
++++ log.cpp
+@@ -1,3 +1,6 @@
++#if defined(__ANDROID__) && defined(RNLLAMA_ANDROID_ENABLE_LOGGING)
++#include 
++#endif
+ #include "log.h"
+ 
+ #include 
+@@ -66,7 +69,36 @@ struct common_log_entry {
+     // signals the worker thread to stop
+     bool is_end;
+ 
++    #if defined(__ANDROID__) && defined(RNLLAMA_ANDROID_ENABLE_LOGGING)
++    void android_print() const {
++        int android_log_priority;
++        switch (level) {
++            case LM_GGML_LOG_LEVEL_INFO:
++                android_log_priority = ANDROID_LOG_INFO;
++                break;
++            case LM_GGML_LOG_LEVEL_WARN:
++                android_log_priority = ANDROID_LOG_WARN;
++                break;
++            case LM_GGML_LOG_LEVEL_ERROR:
++                android_log_priority = ANDROID_LOG_ERROR;
++                break;
++            case LM_GGML_LOG_LEVEL_DEBUG:
++                android_log_priority = ANDROID_LOG_DEBUG;
++                break;
++            default:
++                android_log_priority = ANDROID_LOG_DEFAULT;
++                break;
++        }
++
++        const char * tag = "RNLLAMA_LOG_ANDROID"; 
++        __android_log_print(android_log_priority, tag, "%s", msg.data());
++    }
++    #endif
++
+     void print(FILE * file = nullptr) const {
++        #if defined(__ANDROID__) && defined(RNLLAMA_ANDROID_ENABLE_LOGGING)
++        android_print();
++        #else
+         FILE * fcur = file;
+         if (!fcur) {
+             // stderr displays DBG messages only when their verbosity level is not higher than the threshold
+@@ -111,6 +143,7 @@ struct common_log_entry {
+         }
+ 
+         fflush(fcur);
++        #endif
+     }
+ };
+ 
diff --git a/scripts/log.h.patch b/scripts/log.h.patch
deleted file mode 100644
index 82b3753..0000000
--- a/scripts/log.h.patch
+++ /dev/null
@@ -1,22 +0,0 @@
---- log.h.orig
-+++ log.h
-@@ -323,6 +323,19 @@ inline std::string log_filename_generator_impl(LogTriState multilog, const std::
-     #define LOG_TEELN(str, ...) LOG_TEE_IMPL("%s" str, "", ##__VA_ARGS__, "\n")
- #endif
-
-+#if defined(__ANDROID__) && defined(RNLLAMA_ANDROID_ENABLE_LOGGING)
-+#include 
-+#define LLAMA_ANDROID_LOG_TAG "RNLLAMA_LOG_ANDROID"
-+#undef LOG
-+#undef LOG_TEE
-+#undef LOGLN
-+#undef LOG_TEELN
-+#define LOG(...) __android_log_print(ANDROID_LOG_INFO, LLAMA_ANDROID_LOG_TAG, __VA_ARGS__)
-+#define LOG_TEE(...) __android_log_print(ANDROID_LOG_INFO, LLAMA_ANDROID_LOG_TAG, __VA_ARGS__)
-+#define LOGLN(...) __android_log_print(ANDROID_LOG_INFO, LLAMA_ANDROID_LOG_TAG, __VA_ARGS__)
-+#define LOG_TEELN(...) __android_log_print(ANDROID_LOG_INFO, LLAMA_ANDROID_LOG_TAG, __VA_ARGS__)
-+#endif
-+
- // INTERNAL, DO NOT USE
- inline FILE *log_handler1_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, const std::string & filename = LOG_DEFAULT_FILE_NAME, FILE *target = nullptr)
- {
diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts
index f2d6882..d42d3fb 100644
--- a/src/NativeRNLlama.ts
+++ b/src/NativeRNLlama.ts
@@ -15,6 +15,7 @@ export type NativeContextParams = {
 
   use_mlock?: boolean
   use_mmap?: boolean
+  vocab_only?: boolean
 
   lora?: string // lora_adaptor
   lora_scaled?: number
@@ -34,7 +35,8 @@ export type NativeCompletionParams = {
   top_k?: number
   top_p?: number
   min_p?: number
-  tfs_z?: number
+  xtc_threshold?: number
+  xtc_probability?: number
   typical_p?: number
   temperature?: number // -> temp
   penalty_last_n?: number