diff --git a/cpp/common.cpp b/cpp/common.cpp index 63aed37..987a93f 100644 --- a/cpp/common.cpp +++ b/cpp/common.cpp @@ -73,7 +73,6 @@ char const *LLAMA_BUILD_TARGET = "unknown"; #include #endif #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 -#define LLAMA_CURL_MAX_HEADER_LENGTH 256 #endif // LLAMA_USE_CURL using json = nlohmann::ordered_json; @@ -240,8 +239,54 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return result; } +bool parse_kv_override(const char * data, std::vector & overrides) { + const char * sep = strchr(data, '='); + if (sep == nullptr || sep - data >= 128) { + fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); + return false; + } + llama_model_kv_override kvo; + std::strncpy(kvo.key, data, sep - data); + kvo.key[sep - data] = 0; + sep++; + if (strncmp(sep, "int:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; + kvo.val_i64 = std::atol(sep); + } else if (strncmp(sep, "float:", 6) == 0) { + sep += 6; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; + kvo.val_f64 = std::atof(sep); + } else if (strncmp(sep, "bool:", 5) == 0) { + sep += 5; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; + if (std::strcmp(sep, "true") == 0) { + kvo.val_bool = true; + } else if (std::strcmp(sep, "false") == 0) { + kvo.val_bool = false; + } else { + fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data); + return false; + } + } else if (strncmp(sep, "str:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; + if (strlen(sep) > 127) { + fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); + return false; + } + strncpy(kvo.val_str, sep, 127); + kvo.val_str[127] = '\0'; + } else { + fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); + return false; + } + overrides.emplace_back(std::move(kvo)); + return true; +} + bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { - llama_sampling_params& sparams = params.sparams; + llama_sampling_params & sparams = params.sparams; if (arg == "-s" || arg == "--seed") { if (++i >= argc) { @@ -853,7 +898,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } - params.image = argv[i]; + params.image.emplace_back(argv[i]); return true; } if (arg == "-i" || arg == "--interactive") { @@ -908,6 +953,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cont_batching = true; return true; } + if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; + return true; + } if (arg == "--color") { params.use_color = true; return true; @@ -1095,6 +1144,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.n_print = std::stoi(argv[i]); return true; } + if (arg == "--check-tensors") { + params.check_tensors = true; + return true; + } if (arg == "--ppl-output-type") { if (++i >= argc) { invalid_param = true; @@ -1246,47 +1299,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } - char* sep = strchr(argv[i], '='); - if (sep == nullptr || sep - argv[i] >= 128) { - fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]); - invalid_param = true; - return true; - } - struct llama_model_kv_override kvo; - std::strncpy(kvo.key, argv[i], sep - argv[i]); - kvo.key[sep - argv[i]] = 0; - sep++; - if 
(strncmp(sep, "int:", 4) == 0) { - sep += 4; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.int_value = std::atol(sep); - } - else if (strncmp(sep, "float:", 6) == 0) { - sep += 6; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; - kvo.float_value = std::atof(sep); - } - else if (strncmp(sep, "bool:", 5) == 0) { - sep += 5; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; - if (std::strcmp(sep, "true") == 0) { - kvo.bool_value = true; - } - else if (std::strcmp(sep, "false") == 0) { - kvo.bool_value = false; - } - else { - fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); - invalid_param = true; - return true; - } - } - else { + if (!parse_kv_override(argv[i], params.kv_overrides)) { fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; return true; } - params.kv_overrides.push_back(kvo); return true; } #ifndef LOG_DISABLE_LOGS @@ -1316,6 +1333,29 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return false; } +void gpt_params_handle_model_default(gpt_params & params) { + if (!params.hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (params.hf_file.empty()) { + if (params.model.empty()) { + throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); + } + params.hf_file = params.model; + } else if (params.model.empty()) { + params.model = "models/" + string_split(params.hf_file, '/').back(); + } + } else if (!params.model_url.empty()) { + if (params.model.empty()) { + auto f = string_split(params.model_url, '#').front(); + f = string_split(f, '?').front(); + f = string_split(f, '/').back(); + params.model = "models/" + f; + } + } else if (params.model.empty()) { + params.model = DEFAULT_MODEL_PATH; + } +} + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; std::string arg; @@ -1344,10 +1384,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } - // short-hand to avoid specifying --hf-file -> default it to --model - if (!params.hf_repo.empty() && params.hf_file.empty()) { - params.hf_file = params.model; - } + gpt_params_handle_model_default(params); if (params.escape) { process_escapes(params.prompt); @@ -1486,8 +1523,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); + printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); - printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n"); + printf(" --image IMAGE_FILE path to an image file. use with multimodal models. 
Specify multiple times for batching\n"); if (llama_supports_mlock()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } @@ -1540,7 +1578,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --control-vector-layer-range START END\n"); printf(" layer range to apply the control vector(s) to, start and end inclusive\n"); printf(" -m FNAME, --model FNAME\n"); - printf(" model path (default: %s)\n", params.model.c_str()); + printf(" model path (default: models/$filename with filename from --hf-file or --model-url if set, otherwise %s)\n", DEFAULT_MODEL_PATH); printf(" -md FNAME, --model-draft FNAME\n"); printf(" draft model for speculative decoding (default: unused)\n"); printf(" -mu MODEL_URL, --model-url MODEL_URL\n"); @@ -1557,9 +1595,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" path to dynamic lookup cache to use for lookup decoding (updated by generation)\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); - printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); + printf(" types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); printf(" -ptc N, --print-token-count N\n"); printf(" print token count every N tokens (default: %d)\n", params.n_print); + printf(" --check-tensors check model tensor data for invalid values\n"); printf("\n"); #ifndef LOG_DISABLE_LOGS log_print_usage(); @@ -1684,6 +1723,18 @@ std::vector string_split(std::string input, char separator) { return parts; } +std::string string_strip(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && std::isspace(str[start])) { + start++; + } + while (end > start && std::isspace(str[end - 1])) { + end--; + } + return str.substr(start, end - start); +} + std::vector sampler_types_from_names(const std::vector & names, bool allow_alt_names) { std::unordered_map sampler_canonical_name_map { {"top_k", llama_sampler_type::TOP_K}, @@ -1780,6 +1831,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; mparams.use_mlock = params.use_mlock; + mparams.check_tensors = params.check_tensors; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { @@ -1844,6 +1896,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; + cparams.flash_attn = params.flash_attn; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -1874,59 +1927,75 @@ void llama_batch_add( #ifdef LLAMA_USE_CURL -static bool llama_download_file(CURL * curl, const char * url, const char * path) { +static bool starts_with(const std::string & str, const std::string & prefix) { + // While we wait for C++20's std::string::starts_with... 
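    // Note: str.rfind(prefix, 0) starts the backwards search at index 0, so it can
    // only ever match at the very beginning of str - i.e. it is a prefix test,
    // equivalent to the C++20 starts_with() mentioned above.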
+ return str.rfind(prefix, 0) == 0; +} + +static bool llama_download_file(const std::string & url, const std::string & path) { + + // Initialize libcurl + std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); + if (!curl) { + fprintf(stderr, "%s: error initializing libcurl\n", __func__); + return false; + } + bool force_download = false; // Set the URL, allow to follow http redirection - curl_easy_setopt(curl, CURLOPT_URL, url); - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); #if defined(_WIN32) // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of // operating system. Currently implemented under MS-Windows. - curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); #endif // Check if the file already exists locally struct stat model_file_info; - auto file_exists = (stat(path, &model_file_info) == 0); - - // If the file exists, check for ${path_model}.etag or ${path_model}.lastModified files - char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; - char etag_path[PATH_MAX] = {0}; - snprintf(etag_path, sizeof(etag_path), "%s.etag", path); + auto file_exists = (stat(path.c_str(), &model_file_info) == 0); - char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; - char last_modified_path[PATH_MAX] = {0}; - snprintf(last_modified_path, sizeof(last_modified_path), "%s.lastModified", path); + // If the file exists, check its JSON metadata companion file. + std::string metadata_path = path + ".json"; + nlohmann::json metadata; + std::string etag; + std::string last_modified; if (file_exists) { - auto * f_etag = fopen(etag_path, "r"); - if (f_etag) { - if (!fgets(etag, sizeof(etag), f_etag)) { - fprintf(stderr, "%s: unable to read file %s\n", __func__, etag_path); - } else { - fprintf(stderr, "%s: previous file found %s: %s\n", __func__, etag_path, etag); - } - fclose(f_etag); - } - - auto * f_last_modified = fopen(last_modified_path, "r"); - if (f_last_modified) { - if (!fgets(last_modified, sizeof(last_modified), f_last_modified)) { - fprintf(stderr, "%s: unable to read file %s\n", __func__, last_modified_path); - } else { - fprintf(stderr, "%s: previous file found %s: %s\n", __func__, last_modified_path, - last_modified); + // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). 
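        // For reference, the companion <path>.json written near the end of this
        // function holds exactly these fields (values illustrative):
        //   { "url": "https://...", "etag": "\"<opaque tag>\"", "lastModified": "<HTTP date>" }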
+ std::ifstream metadata_in(metadata_path); + if (metadata_in.good()) { + try { + metadata_in >> metadata; + fprintf(stderr, "%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); + if (metadata.contains("url") && metadata["url"].is_string()) { + auto previous_url = metadata["url"].get(); + if (previous_url != url) { + fprintf(stderr, "%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); + return false; + } + } + if (metadata.contains("etag") && metadata["etag"].is_string()) { + etag = metadata["etag"]; + } + if (metadata.contains("lastModified") && metadata["lastModified"].is_string()) { + last_modified = metadata["lastModified"]; + } + } catch (const nlohmann::json::exception & e) { + fprintf(stderr, "%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + return false; } - fclose(f_last_modified); } + } else { + fprintf(stderr, "%s: no previous model file found %s\n", __func__, path.c_str()); } // Send a HEAD request to retrieve the etag and last-modified headers struct llama_load_model_from_url_headers { - char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; - char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; + std::string etag; + std::string last_modified; }; llama_load_model_from_url_headers headers; { @@ -1934,38 +2003,37 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata; - // Convert header field name to lowercase - for (size_t i = 0; i < n_items && buffer[i] != ':'; ++i) { - buffer[i] = tolower(buffer[i]); - } - - const char * etag_prefix = "etag: "; - if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) { - strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix) - 2); // Remove CRLF - } - - const char * last_modified_prefix = "last-modified: "; - if (strncmp(buffer, last_modified_prefix, strlen(last_modified_prefix)) == 0) { - strncpy(headers->last_modified, buffer + strlen(last_modified_prefix), - n_items - strlen(last_modified_prefix) - 2); // Remove CRLF + static std::regex header_regex("([^:]+): (.*)\r\n"); + static std::regex etag_regex("ETag", std::regex_constants::icase); + static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); + + std::string header(buffer, n_items); + std::smatch match; + if (std::regex_match(header, match, header_regex)) { + const std::string & key = match[1]; + const std::string & value = match[2]; + if (std::regex_match(key, match, etag_regex)) { + headers->etag = value; + } else if (std::regex_match(key, match, last_modified_regex)) { + headers->last_modified = value; + } } return n_items; }; - curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress - curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, static_cast(header_callback)); - curl_easy_setopt(curl, CURLOPT_HEADERDATA, &headers); + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); + curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - CURLcode res = curl_easy_perform(curl); + CURLcode res = 
curl_easy_perform(curl.get()); if (res != CURLE_OK) { - curl_easy_cleanup(curl); fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); return false; } long http_code = 0; - curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); if (http_code != 200) { // HEAD not supported, we don't know if the file has changed // force trigger downloading @@ -1974,28 +2042,30 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path } } - // If the ETag or the Last-Modified headers are different: trigger a new download - bool should_download = !file_exists - || force_download - || (strlen(headers.etag) > 0 && strcmp(etag, headers.etag) != 0) - || (strlen(headers.last_modified) > 0 && strcmp(last_modified, headers.last_modified) != 0); + bool should_download = !file_exists || force_download; + if (!should_download) { + if (!etag.empty() && etag != headers.etag) { + fprintf(stderr, "%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); + should_download = true; + } else if (!last_modified.empty() && last_modified != headers.last_modified) { + fprintf(stderr, "%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); + should_download = true; + } + } if (should_download) { - char path_temporary[PATH_MAX] = {0}; - snprintf(path_temporary, sizeof(path_temporary), "%s.downloadInProgress", path); + std::string path_temporary = path + ".downloadInProgress"; if (file_exists) { - fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path); - if (remove(path) != 0) { - curl_easy_cleanup(curl); - fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path); + fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); + if (remove(path.c_str()) != 0) { + fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path.c_str()); return false; } } // Set the output file - auto * outfile = fopen(path_temporary, "wb"); + std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb"), fclose); if (!outfile) { - curl_easy_cleanup(curl); - fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path); + fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str()); return false; } @@ -2003,12 +2073,12 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { return fwrite(data, size, nmemb, (FILE *)fd); }; - curl_easy_setopt(curl, CURLOPT_NOBODY, 0L); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile); + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L); + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); // display download progress - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L); // helper function to hide password in URL auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string { @@ -2027,51 +2097,34 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path // start the download fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, 
server_last_modified:%s)...\n", __func__, - llama_download_hide_password_in_url(url).c_str(), path, headers.etag, headers.last_modified); - auto res = curl_easy_perform(curl); + llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); + auto res = curl_easy_perform(curl.get()); if (res != CURLE_OK) { - fclose(outfile); - curl_easy_cleanup(curl); fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); return false; } long http_code = 0; - curl_easy_getinfo (curl, CURLINFO_RESPONSE_CODE, &http_code); + curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); if (http_code < 200 || http_code >= 400) { - fclose(outfile); - curl_easy_cleanup(curl); fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code); return false; } - // Clean up - fclose(outfile); + // Causes file to be closed explicitly here before we rename it. + outfile.reset(); - // Write the new ETag to the .etag file - if (strlen(headers.etag) > 0) { - auto * etag_file = fopen(etag_path, "w"); - if (etag_file) { - fputs(headers.etag, etag_file); - fclose(etag_file); - fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag); - } - } - - // Write the new lastModified to the .etag file - if (strlen(headers.last_modified) > 0) { - auto * last_modified_file = fopen(last_modified_path, "w"); - if (last_modified_file) { - fputs(headers.last_modified, last_modified_file); - fclose(last_modified_file); - fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path, - headers.last_modified); - } - } + // Write the updated JSON metadata file. + metadata.update({ + {"url", url}, + {"etag", headers.etag}, + {"lastModified", headers.last_modified} + }); + std::ofstream(metadata_path) << metadata.dump(4); + fprintf(stderr, "%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); - if (rename(path_temporary, path) != 0) { - curl_easy_cleanup(curl); - fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary, path); + if (rename(path_temporary.c_str(), path.c_str()) != 0) { + fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); return false; } } @@ -2089,15 +2142,7 @@ struct llama_model * llama_load_model_from_url( return NULL; } - // Initialize libcurl - auto * curl = curl_easy_init(); - - if (!curl) { - fprintf(stderr, "%s: error initializing libcurl\n", __func__); - return NULL; - } - - if (!llama_download_file(curl, model_url, path_model)) { + if (!llama_download_file(model_url, path_model)) { return NULL; } @@ -2111,7 +2156,6 @@ struct llama_model * llama_load_model_from_url( auto * ctx_gguf = lm_gguf_init_from_file(path_model, lm_gguf_params); if (!ctx_gguf) { fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, path_model); - curl_easy_cleanup(curl); return NULL; } @@ -2123,8 +2167,6 @@ struct llama_model * llama_load_model_from_url( lm_gguf_free(ctx_gguf); } - curl_easy_cleanup(curl); - if (n_split > 1) { char split_prefix[PATH_MAX] = {0}; char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0}; @@ -2155,11 +2197,7 @@ struct llama_model * llama_load_model_from_url( char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); - auto * curl = curl_easy_init(); - bool res = llama_download_file(curl, split_url, split_path); - curl_easy_cleanup(curl); - - return res; + return 
llama_download_file(split_url, split_path); }, idx)); } @@ -2646,7 +2684,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); - fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str()); + fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); @@ -2681,6 +2719,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); + fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/cpp/common.h b/cpp/common.h index e4c971e..e326cb1 100644 --- a/cpp/common.h +++ b/cpp/common.h @@ -31,6 +31,8 @@ fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ } while(0) +#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" + // build info extern int LLAMA_BUILD_NUMBER; extern char const *LLAMA_COMMIT; @@ -103,7 +105,7 @@ struct gpt_params { // // sampling parameters struct llama_sampling_params sparams; - std::string model = "models/7B/ggml-model-f16.gguf"; // model path + std::string model = ""; // model path std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias std::string model_url = ""; // model url to download @@ -144,7 +146,7 @@ struct gpt_params { bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. 
If 0, all tasks will be computed - bool kl_divergence = false; // compute KL-divergence + bool kl_divergence = false; // compute KL divergence bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs @@ -159,6 +161,7 @@ struct gpt_params { bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly + bool flash_attn = false; // flash attention bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens @@ -172,15 +175,20 @@ struct gpt_params { bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run + bool check_tensors = false; // validate tensor data std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V // multimodal models (see examples/llava) - std::string mmproj = ""; // path to multimodal projector - std::string image = ""; // path to an image file + std::string mmproj = ""; // path to multimodal projector + std::vector image; // path to image file(s) }; +void gpt_params_handle_model_default(gpt_params & params); + +bool parse_kv_override(const char * data, std::vector & overrides); + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); bool gpt_params_parse(int argc, char ** argv, gpt_params & params); @@ -204,6 +212,7 @@ bool validate_file_name(const std::string & filename); std::vector sampler_types_from_names(const std::vector & names, bool allow_alt_names); std::vector sampler_types_from_chars(const std::string & names_string); std::vector string_split(std::string input, char separator); +std::string string_strip(const std::string & str); std::string sampler_type_to_name_string(llama_sampler_type sampler_type); // diff --git a/cpp/ggml-backend.c b/cpp/ggml-backend.c index ed53f8b..744fd39 100644 --- a/cpp/ggml-backend.c +++ b/cpp/ggml-backend.c @@ -1784,12 +1784,14 @@ void lm_ggml_backend_sched_free(lm_ggml_backend_sched_t sched) { void lm_ggml_backend_sched_reset(lm_ggml_backend_sched_t sched) { // reset state for the next run - size_t hash_size = sched->hash_set.size; - memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT - memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size); - memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size); + if (!sched->is_reset) { + size_t hash_size = sched->hash_set.size; + memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT + memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size); + memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size); - sched->is_reset = true; + sched->is_reset = true; + } sched->is_alloc = false; } diff --git a/cpp/ggml-impl.h b/cpp/ggml-impl.h index e821630..53315a6 100644 --- a/cpp/ggml-impl.h +++ b/cpp/ggml-impl.h @@ -11,6 +11,12 @@ #include // memcpy #include // fabsf +#undef MIN +#undef MAX + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? 
(a) : (b)) + #ifdef __cplusplus extern "C" { #endif @@ -307,7 +313,7 @@ inline static int32x4_t lm_ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t #endif // defined(__ARM_NEON) -#if defined(__ARM_NEON) && !defined(__MSC_VER) +#if defined(__ARM_NEON) && !defined(_MSC_VER) #define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) #define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index f1a393b..502c61b 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -46,8 +46,10 @@ LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, LM_GGML_METAL_KERNEL_TYPE_SILU, LM_GGML_METAL_KERNEL_TYPE_SILU_4, - LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX, - LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, + LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, + LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, + LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, + LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, @@ -177,6 +179,14 @@ LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -443,7 +453,7 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, } /* - LM_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + LM_GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ (int) kernel->pipeline.threadExecutionWidth); \ */ @@ -459,172 +469,182 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, return NULL; \ } \ } else { \ - LM_GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \ + LM_GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ } // simd_sum and simd_max requires MTLGPUFamilyApple7 - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD, add, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL, mul, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV, div, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RELU, relu, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - 
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, 
ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - 
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, 
ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - 
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD, add, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL, mul, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV, div, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RELU, relu, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, 
ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + 
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + 
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, 
ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + 
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } [metal_library release]; @@ -743,6 +763,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_metal_context * ctx, case LM_GGML_OP_TIMESTEP_EMBEDDING: case LM_GGML_OP_ARGSORT: case LM_GGML_OP_LEAKY_RELU: + case LM_GGML_OP_FLASH_ATTN_EXT: return true; case LM_GGML_OP_MUL_MAT: case LM_GGML_OP_MUL_MAT_ID: @@ -1326,20 +1347,33 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( } break; case LM_GGML_OP_SOFT_MAX: { + LM_GGML_ASSERT(!src1 || src1->type == LM_GGML_TYPE_F16 || src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(!src2 || src2->type == 
LM_GGML_TYPE_F16 || src2->type == LM_GGML_TYPE_F32); + int nth = 32; // SIMD width id pipeline = nil; + const bool use_f16 = (src1 && src1->type == LM_GGML_TYPE_F16) || (src2 && src2->type == LM_GGML_TYPE_F16); + if (ne00%4 == 0) { while (nth < ne00/4 && nth < 256) { nth *= 2; } - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline; + if (use_f16) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } } else { while (nth < ne00 && nth < 1024) { nth *= 2; } - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; + if (use_f16) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } } float scale; @@ -2503,6 +2537,161 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case LM_GGML_OP_FLASH_ATTN_EXT: + { + LM_GGML_ASSERT(ne00 % 4 == 0); + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + + struct lm_ggml_tensor * src3 = gf->nodes[i]->src[3]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src1, src2)); + LM_GGML_ASSERT(src3); + + size_t offs_src3 = 0; + + id id_src3 = src3 ? lm_ggml_metal_get_buffer(src3, &offs_src3) : nil; + + LM_GGML_ASSERT(!src3 || src3->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(!src3 || src3->ne[1] >= LM_GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; LM_GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; LM_GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; LM_GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; LM_GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; LM_GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; LM_GGML_UNUSED(nb33); + + const enum lm_ggml_type src2t = src2 ? 
src2->type : LM_GGML_TYPE_COUNT; LM_GGML_UNUSED(src2t); + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + id pipeline = nil; + + bool use_vec_kernel = false; + + if (ne01 >= 4 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + LM_GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ASSERT(false && "add template specialization for this size"); + } + } + } else { + use_vec_kernel = true; + + switch (ne00) { + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + LM_GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ASSERT(false && "add template specialization for this size"); + } + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! 
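A quick worked example of the threadgroup-memory formula that the dispatch code below relies on, so the constants are easier to follow. This is an illustrative sketch, not part of the patch: the 32 KiB budget is an assumption (typical for Apple-silicon GPUs), whereas the real code queries [ctx->device maxThreadgroupMemoryLength].

#include <cstdio>
#include <cstddef>

// smem = nqptg*(D + 2*nsg*(ncpsg + nqptg)) half-precision elements, 2 bytes each
static size_t fa_smem_bytes(size_t D, size_t nsg, size_t nqptg = 8, size_t ncpsg = 32) {
    return nqptg*(D + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
}

int main() {
    const size_t budget = 32*1024;                        // assumed threadgroup memory limit
    const size_t smem   = fa_smem_bytes(/*D=*/128, /*nsg=*/4);
    // D = 128, nsg = 4 simdgroups: 8*(128 + 8*40)*2 = 7168 bytes, well under 32 KiB
    printf("smem = %zu bytes (budget %zu)\n", smem, budget);
    return smem <= budget ? 0 : 1;
}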
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + LM_GGML_ASSERT(nqptg <= 32); + LM_GGML_ASSERT(nqptg % 8 == 0); + LM_GGML_ASSERT(ncpsg % 32 == 0); + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + LM_GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + + [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + LM_GGML_ASSERT(nqptg <= 32); + LM_GGML_ASSERT(nqptg % 1 == 0); + LM_GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + LM_GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; case LM_GGML_OP_DUP: case LM_GGML_OP_CPY: case LM_GGML_OP_CONT: @@ -2590,6 +2779,11 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( MTLCommandBufferStatus status = [command_buffer status]; if (status != MTLCommandBufferStatusCompleted) { LM_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + NSString * error_code = [command_buffer error].localizedDescription; + LM_GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]); + } + return LM_GGML_STATUS_FAILED; } } @@ -2706,10 +2900,13 @@ LM_GGML_CALL static void lm_ggml_backend_metal_buffer_clear(lm_ggml_backend_buff UNUSED(buft); } -static void lm_ggml_backend_metal_log_allocated_size(id device) { +static void lm_ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef LM_GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - LM_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + __func__, + size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -2719,10 +2916,15 @@ static void lm_ggml_backend_metal_log_allocated_size(id device) { LM_GGML_METAL_LOG_INFO("\n"); } } else { - LM_GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + 
LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); } +#endif #endif UNUSED(device); + UNUSED(size_aligned); } LM_GGML_CALL static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { @@ -2756,8 +2958,7 @@ LM_GGML_CALL static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_a return NULL; } - LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); - lm_ggml_backend_metal_log_allocated_size(device); + //lm_ggml_backend_metal_log_allocated_size(device, size_aligned); return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_metal_buffer_i, ctx, size); } @@ -2844,7 +3045,7 @@ LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void return false; } - LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + lm_ggml_backend_metal_log_allocated_size(device, size_aligned); ++ctx->n_buffers; } else { @@ -2867,7 +3068,8 @@ LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void return false; } - LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + lm_ggml_backend_metal_log_allocated_size(device, size_step_aligned); + if (i + size_step < size) { LM_GGML_METAL_LOG_INFO("\n"); } @@ -2876,8 +3078,6 @@ LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void } } - lm_ggml_backend_metal_log_allocated_size(device); - return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_type(), lm_ggml_backend_metal_buffer_i, ctx, size); } diff --git a/cpp/ggml-quants.c b/cpp/ggml-quants.c index 105322a..69f30bd 100644 --- a/cpp/ggml-quants.c +++ b/cpp/ggml-quants.c @@ -14,12 +14,6 @@ #include // for qsort #include // for LM_GGML_ASSERT -#undef MIN -#undef MAX - -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - #define UNUSED LM_GGML_UNUSED // some compilers don't provide _mm256_set_m128i, e.g. 
gcc 7 @@ -12389,3 +12383,287 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) block_iq2_s * restrict y = vy; quantize_row_iq2_s_reference(x, y, k); } + +static bool validate_float(float f, size_t i) { + if (isinf(f)) { + fprintf(stderr, "lm_ggml_validate_row_data: found inf value at block %zu\n", i); + return false; + } + + if (isnan(f)) { + fprintf(stderr, "lm_ggml_validate_row_data: found nan value at block %zu\n", i); + return false; + } + + return true; +} + +static bool isinf_fp16(lm_ggml_fp16_t f) { + return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0; +} + +static bool isnan_fp16(lm_ggml_fp16_t f) { + return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0; +} + +static bool validate_fp16(lm_ggml_fp16_t f, size_t i) { + if (isinf_fp16(f)) { + fprintf(stderr, "lm_ggml_validate_row_data: found inf value at block %zu\n", i); + return false; + } + + if (isnan_fp16(f)) { + fprintf(stderr, "lm_ggml_validate_row_data: found nan value at block %zu\n", i); + return false; + } + + return true; +} + +#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + if (!validate_fp16(q[i].d, i)) { \ + return false; \ + } \ + } + +#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \ + return false; \ + } \ + } + +bool lm_ggml_validate_row_data(enum lm_ggml_type type, const void * data, size_t nbytes) { + if (type < 0 || type >= LM_GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid type %d\n", __func__, type); + return false; + } + + if (nbytes % lm_ggml_type_size(type) != 0) { + fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type); + return false; + } + + const size_t nb = nbytes/lm_ggml_type_size(type); + + switch (type) { + case LM_GGML_TYPE_F16: + { + const lm_ggml_fp16_t * f = (const lm_ggml_fp16_t *) data; + size_t i = 0; +#if defined(__AVX2__) + for (; i + 15 < nb; i += 16) { + __m256i v = _mm256_loadu_si256((const __m256i *)(f + i)); + __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00)); + __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00)); + int mask = _mm256_movemask_epi8(cmp); + if (mask) { + for (size_t j = 0; j < 16; ++j) { + if (!validate_fp16(f[i + j], i + j)) { + return false; + } + } + LM_GGML_UNREACHABLE(); + } + } +#elif defined(__ARM_NEON) + for (; i + 7 < nb; i += 8) { + uint16x8_t v = vld1q_u16(f + i); + uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00)); + uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00)); + uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0); + if (mask) { + for (size_t j = 0; j < 8; ++j) { + if (!validate_fp16(f[i + j], i + j)) { + return false; + } + } + LM_GGML_UNREACHABLE(); + } + } +#endif + for (; i < nb; ++i) { + if (!validate_fp16(f[i], i)) { + return false; + } + } + } break; + case LM_GGML_TYPE_F32: + { + const float * f = (const float *) data; + size_t i = 0; +#if defined(__AVX2__) + for (; i + 7 < nb; i += 8) { + __m256i v = _mm256_loadu_si256((const __m256i *)(f + i)); + __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000)); + __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000)); + int mask = _mm256_movemask_epi8(cmp); + if (mask) { + for (size_t j = 0; j < 8; ++j) { + if (!validate_float(f[i + j], i + j)) { + return false; + } + } + LM_GGML_UNREACHABLE(); + } + } +#elif 
defined(__ARM_NEON) + for (; i + 3 < nb; i += 4) { + uint32x4_t v = vld1q_u32((const uint32_t *)f + i); + uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000)); + uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000)); + uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0); + if (mask) { + for (size_t j = 0; j < 4; ++j) { + if (!validate_float(f[i + j], i + j)) { + return false; + } + } + LM_GGML_UNREACHABLE(); + } + } +#endif + for (; i < nb; ++i) { + if (!validate_float(f[i], i)) { + return false; + } + } + } break; + case LM_GGML_TYPE_F64: + { + const double * f = (const double *) data; + for (size_t i = 0; i < nb; ++i) { + if (!validate_float(f[i], i)) { + return false; + } + } + } break; + case LM_GGML_TYPE_Q4_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); + } break; + case LM_GGML_TYPE_Q4_1: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m); + } break; + case LM_GGML_TYPE_Q5_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb); + } break; + case LM_GGML_TYPE_Q5_1: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m); + } break; + case LM_GGML_TYPE_Q8_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb); + } break; + case LM_GGML_TYPE_Q2_K: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin); + } break; + case LM_GGML_TYPE_Q3_K: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb); + } break; + case LM_GGML_TYPE_Q4_K: + { + #ifdef LM_GGML_QKK_64 + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]); + #else + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin); + #endif + } break; + case LM_GGML_TYPE_Q5_K: + { + #ifdef LM_GGML_QKK_64 + VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb); + #else + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin); + #endif + } break; + case LM_GGML_TYPE_Q6_K: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb); + } break; + case LM_GGML_TYPE_Q8_K: + { + const block_q8_K * q = (const block_q8_K *) data; + for (size_t i = 0; i < nb; ++i) { + if (!validate_float(q[i].d, i)) { + return false; + } + } + } break; + case LM_GGML_TYPE_IQ1_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb); + } break; + case LM_GGML_TYPE_IQ1_M: + { + const block_iq1_m * q = (const block_iq1_m *) data; + for (size_t i = 0; i < nb; ++i) { + #if QK_K == 64 + if (!validate_fp16(q[i].d, i)) { + return false; + } + #else + iq1m_scale_t scale; + const uint16_t * sc = (const uint16_t *)q[i].scales; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + if (!validate_fp16(scale.f16, i)) { + return false; + } + #endif + } + } break; + case LM_GGML_TYPE_IQ2_XXS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb); + } break; + case LM_GGML_TYPE_IQ2_XS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb); + } break; + case LM_GGML_TYPE_IQ2_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb); + } break; + case LM_GGML_TYPE_IQ3_XXS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb); + } break; + + case LM_GGML_TYPE_IQ3_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb); + } break; + case LM_GGML_TYPE_IQ4_XS: + #if QK_K != 64 + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb); + } break; + #endif + // with QK_K == 64, iq4_xs is iq4_nl + case LM_GGML_TYPE_IQ4_NL: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); + } break; + case LM_GGML_TYPE_I8: + case LM_GGML_TYPE_I16: + case LM_GGML_TYPE_I32: + case LM_GGML_TYPE_I64: + // nothing to 
validate + break; + default: + { + fprintf(stderr, "%s: invalid type %d\n", __func__, type); + return false; + } + } + + return true; +} diff --git a/cpp/ggml.c b/cpp/ggml.c index 51f91c6..9e656d9 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -858,18 +858,6 @@ lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type) { // simd mappings // -#if defined(__ARM_NEON) -#if !defined(__aarch64__) - -// 64-bit compatibility - -inline static float vaddvq_f32(float32x4_t v) { - return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); -} - -#endif -#endif - // we define a common set of C macros which map to specific intrinsics based on the current architecture // we then implement the fundamental computation operations below using only these macros // adding support for new architectures requires to define the corresponding SIMD macros @@ -963,7 +951,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define LM_GGML_F16_VEC_ZERO LM_GGML_F16x8_ZERO #define LM_GGML_F16_VEC_SET1 LM_GGML_F16x8_SET1 #define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F16x8_LOAD(p) - #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x8_STORE(p, r[i]) + #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x8_STORE((lm_ggml_fp16_internal_t *)(p), r[i]) #define LM_GGML_F16_VEC_FMA LM_GGML_F16x8_FMA #define LM_GGML_F16_VEC_ADD LM_GGML_F16x8_ADD #define LM_GGML_F16_VEC_MUL LM_GGML_F16x8_MUL @@ -989,7 +977,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx4_ZERO #define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx4_SET1 #define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx4_LOAD(p) - #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx4_STORE(p, r[i]) + #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx4_STORE((lm_ggml_fp16_internal_t *)(p), r[i]) #define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx4_FMA #define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx4_ADD #define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx4_MUL @@ -1058,7 +1046,7 @@ do { \ // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F // so F16C guard isn't required -#define LM_GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x))) +#define LM_GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) #define LM_GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) #define LM_GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) @@ -1156,7 +1144,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define LM_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define LM_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define LM_GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(lm_ggml_fp16_t *x) { @@ -1674,6 +1662,37 @@ inline static void lm_ggml_vec_mad_f32(const int n, float * restrict y, const fl #endif } +inline static void lm_ggml_vec_mad_f16(const int n, lm_ggml_fp16_t * restrict y, const lm_ggml_fp16_t * restrict x, const float v) { +#if defined(LM_GGML_SIMD) + const int np = (n & ~(LM_GGML_F16_STEP - 1)); + + LM_GGML_F16_VEC vx = LM_GGML_F16_VEC_SET1(v); + + LM_GGML_F16_VEC ax[LM_GGML_F16_ARR]; + LM_GGML_F16_VEC ay[LM_GGML_F16_ARR]; + + for (int i = 0; i < np; i += LM_GGML_F16_STEP) { + for (int j = 0; j < LM_GGML_F16_ARR; j++) { + ax[j] = LM_GGML_F16_VEC_LOAD(x + i + j*LM_GGML_F16_EPR, j); + ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j); + ay[j] = 
LM_GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + LM_GGML_F16_VEC_STORE(y + i + j*LM_GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i]) + LM_GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i]) + LM_GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void lm_ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1758,6 +1777,35 @@ inline static void lm_ggml_vec_scale_f32(const int n, float * y, const float v #endif } +inline static void lm_ggml_vec_scale_f16(const int n, lm_ggml_fp16_t * y, const float v) { +#if defined(LM_GGML_SIMD) + const int np = (n & ~(LM_GGML_F16_STEP - 1)); + + LM_GGML_F16_VEC vx = LM_GGML_F16_VEC_SET1(v); + + LM_GGML_F16_VEC ay[LM_GGML_F16_ARR]; + + for (int i = 0; i < np; i += LM_GGML_F16_STEP) { + for (int j = 0; j < LM_GGML_F16_ARR; j++) { + ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j); + ay[j] = LM_GGML_F16_VEC_MUL(ay[j], vx); + + LM_GGML_F16_VEC_STORE(y + i + j*LM_GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void lm_ggml_vec_norm_f32 (const int n, float * s, const float * x) { lm_ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void lm_ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void lm_ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -2012,6 +2060,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "SSM_CONV", @@ -2038,7 +2087,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(LM_GGML_OP_COUNT == 76, "LM_GGML_OP_COUNT != 76"); +static_assert(LM_GGML_OP_COUNT == 77, "LM_GGML_OP_COUNT != 77"); static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "none", @@ -2102,6 +2151,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "ssm_conv(x)", @@ -2128,7 +2178,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(LM_GGML_OP_COUNT == 76, "LM_GGML_OP_COUNT != 76"); +static_assert(LM_GGML_OP_COUNT == 77, "LM_GGML_OP_COUNT != 77"); static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2"); @@ -4571,6 +4621,8 @@ struct lm_ggml_tensor * lm_ggml_mul_mat( void lm_ggml_mul_mat_set_prec( struct lm_ggml_tensor * a, enum lm_ggml_prec prec) { + LM_GGML_ASSERT(a->op == LM_GGML_OP_MUL_MAT); + const int32_t prec_i32 = (int32_t) prec; lm_ggml_set_op_params_i32(a, 0, prec_i32); @@ -5409,17 +5461,23 @@ static struct lm_ggml_tensor * lm_ggml_soft_max_impl( LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); if (mask) { + LM_GGML_ASSERT(mask->type == LM_GGML_TYPE_F16 || mask->type == LM_GGML_TYPE_F32); LM_GGML_ASSERT(lm_ggml_is_contiguous(mask)); LM_GGML_ASSERT(lm_ggml_is_matrix(mask)); - LM_GGML_ASSERT(lm_ggml_can_repeat_rows(mask, a)); + 
LM_GGML_ASSERT(mask->ne[0] == a->ne[0]); + LM_GGML_ASSERT(mask->ne[1] >= a->ne[1]); } if (pos) { LM_GGML_ASSERT(lm_ggml_is_vector(pos)); - LM_GGML_ASSERT(pos->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(pos->type == LM_GGML_TYPE_F16 || pos->type == LM_GGML_TYPE_F32); LM_GGML_ASSERT(pos->ne[0] == a->ne[0]); } + if (pos && mask) { + LM_GGML_ASSERT(pos->type == mask->type); + } + if (max_bias > 0.0f) { LM_GGML_ASSERT(pos); } @@ -6228,6 +6286,59 @@ struct lm_ggml_tensor * lm_ggml_flash_attn( return result; } +// lm_ggml_flash_attn_ext + +struct lm_ggml_tensor * lm_ggml_flash_attn_ext( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * q, + struct lm_ggml_tensor * k, + struct lm_ggml_tensor * v, + struct lm_ggml_tensor * mask, + float scale) { + LM_GGML_ASSERT(lm_ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(mask)); + LM_GGML_ASSERT(mask->ne[2] == 1); + LM_GGML_ASSERT(mask->ne[3] == 1); + LM_GGML_ASSERT(mask->ne[1] >= LM_GGML_PAD(q->ne[1], LM_GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to LM_GGML_KQ_MASK_PAD and at least n_queries big"); + //LM_GGML_ASSERT(lm_ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + + float params[] = { scale }; + lm_ggml_set_op_params(result, params, sizeof(params)); + + result->op = LM_GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + +void lm_ggml_flash_attn_ext_set_prec( + struct lm_ggml_tensor * a, + enum lm_ggml_prec prec) { + LM_GGML_ASSERT(a->op == LM_GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + lm_ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // lm_ggml_flash_ff struct lm_ggml_tensor * lm_ggml_flash_ff( @@ -12267,7 +12378,7 @@ static void lm_ggml_compute_forward_soft_max_f32( LM_GGML_TENSOR_UNARY_OP_LOCALS - const int64_t ne11 = src1 ? src1->ne[1] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 @@ -12290,19 +12401,31 @@ static void lm_ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - float * pos = src2 ? (float *) src2->data : src0->data; + lm_ggml_fp16_t * pos_f16 = src2 ? (lm_ggml_fp16_t *) src2->data : src0->data; + float * pos_f32 = src2 ? (float *) src2->data : src0->data; + + const bool use_f16 = (src1 && src1->type == LM_GGML_TYPE_F16) || (src2 && src2->type == LM_GGML_TYPE_F16); for (int i1 = ir0; i1 < ir1; i1++) { float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + lm_ggml_fp16_t * mp_f16 = src1 ? (lm_ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * mp_f32 = src1 ? 
(float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; lm_ggml_vec_cpy_f32 (nc, wp, sp); lm_ggml_vec_scale_f32(nc, wp, scale); - if (mp) { - lm_ggml_vec_acc_f32(nc, wp, mp); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += LM_GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += mp_f32[i]; + } + } } // ALiBi bias @@ -12310,8 +12433,14 @@ static void lm_ggml_compute_forward_soft_max_f32( const uint32_t h = (i1/ne01)%ne02; // head const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); - for (int i = 0; i < nc; i++) { - wp[i] = wp[i] + slope*pos[i]; + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*LM_GGML_FP16_TO_FP32(pos_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*pos_f32[i]; + } } } @@ -14581,6 +14710,198 @@ static void lm_ggml_compute_forward_flash_attn( } } +// lm_ggml_compute_forward_flash_attn_ext + +static void lm_ggml_compute_forward_flash_attn_ext_f16( + const struct lm_ggml_compute_params * params, + const struct lm_ggml_tensor * q, + const struct lm_ggml_tensor * k, + const struct lm_ggml_tensor * v, + const struct lm_ggml_tensor * mask, + struct lm_ggml_tensor * dst) { + int64_t t0 = lm_ggml_perf_time_us(); + UNUSED(t0); + + LM_GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + LM_GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + LM_GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + LM_GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + LM_GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + LM_GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + LM_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + + LM_GGML_ASSERT(ne0 == D); + LM_GGML_ASSERT(ne2 == N); + + LM_GGML_ASSERT(nbq0 == sizeof(float)); + LM_GGML_ASSERT(nbk0 == sizeof(lm_ggml_fp16_t)); + LM_GGML_ASSERT(nbv0 == sizeof(lm_ggml_fp16_t)); + + LM_GGML_ASSERT(neq0 == D); + LM_GGML_ASSERT(nek0 == D); + LM_GGML_ASSERT(nev0 == D); + + LM_GGML_ASSERT(neq1 == N); + LM_GGML_ASSERT(nev0 == D); + + // dst cannot be transposed or permuted + LM_GGML_ASSERT(nb0 == sizeof(float)); + LM_GGML_ASSERT(nb0 <= nb1); + LM_GGML_ASSERT(nb1 <= nb2); + LM_GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + if (params->type == LM_GGML_TASK_TYPE_INIT) { + return; + } + + if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { + return; + } + + // parallelize by q rows using lm_ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + // loop over n_batch and n_head + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float S = 0.0f; + float M = -INFINITY; + + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + lm_ggml_fp16_t * Q16 = (lm_ggml_fp16_t *) (V32); // reuse memory + lm_ggml_fp16_t * V16 = (lm_ggml_fp16_t *) (V32 + D); + + memset(V16, 0, D*sizeof(lm_ggml_fp16_t)); + + const lm_ggml_fp16_t * mp = mask ? 
(lm_ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? LM_GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + if (mv == -INFINITY) { + continue; + } + + float s; + + // convert Q to F16 in V32 + { + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = LM_GGML_FP32_TO_FP16(pq[d]); + } + } + + lm_ggml_vec_dot_f16(D, + &s, 0, + (lm_ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, + Q16, 0, 1); + + s = s*scale + mv; + + const float Mold = M; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + lm_ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); + } + + const lm_ggml_fp16_t * v16 = (const lm_ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + // V += v*expf(s - M) + lm_ggml_vec_mad_f16(D, V16, v16, vs); + + S = S*ms + vs; + } + + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = LM_GGML_FP16_TO_FP32(V16[d])/S; + } + + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); + } +} + +static void lm_ggml_compute_forward_flash_attn_ext( + const struct lm_ggml_compute_params * params, + const struct lm_ggml_tensor * q, + const struct lm_ggml_tensor * k, + const struct lm_ggml_tensor * v, + const struct lm_ggml_tensor * mask, + struct lm_ggml_tensor * dst) { + switch (dst->op_params[1]) { + case LM_GGML_PREC_DEFAULT: + case LM_GGML_PREC_F32: + { + // uses F32 accumulators + lm_ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + LM_GGML_ASSERT(false); + } break; + } +} + // lm_ggml_compute_forward_flash_ff static void lm_ggml_compute_forward_flash_ff_f16( @@ -16388,6 +16709,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru const bool masked = t != 0; lm_ggml_compute_forward_flash_attn(params, masked, tensor); } break; + case LM_GGML_OP_FLASH_ATTN_EXT: + { + lm_ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case LM_GGML_OP_FLASH_FF: { lm_ggml_compute_forward_flash_ff(params, tensor); @@ -17400,6 +17725,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm LM_GGML_ASSERT(false); // TODO: not implemented } break; case LM_GGML_OP_FLASH_ATTN: + case LM_GGML_OP_FLASH_ATTN_EXT: { struct lm_ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -18172,6 +18498,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads, int n_tasks = n_threads; } break; case LM_GGML_OP_FLASH_ATTN: + case LM_GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -18575,6 +18902,12 @@ struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, in cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case LM_GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; 
// D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case LM_GGML_OP_FLASH_FF: { if (node->src[1]->type == LM_GGML_TYPE_F32) { @@ -20626,7 +20959,7 @@ static void lm_gguf_free_kv(struct lm_gguf_kv * kv) { } struct lm_gguf_context * lm_gguf_init_empty(void) { - struct lm_gguf_context * ctx = LM_GGML_ALIGNED_MALLOC(sizeof(struct lm_gguf_context)); + struct lm_gguf_context * ctx = LM_GGML_CALLOC(1, sizeof(struct lm_gguf_context)); memcpy(ctx->header.magic, LM_GGUF_MAGIC, sizeof(ctx->header.magic)); ctx->header.version = LM_GGUF_VERSION; @@ -20671,7 +21004,7 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg bool ok = true; - struct lm_gguf_context * ctx = LM_GGML_ALIGNED_MALLOC(sizeof(struct lm_gguf_context)); + struct lm_gguf_context * ctx = LM_GGML_CALLOC(1, sizeof(struct lm_gguf_context)); // read the header { @@ -20708,9 +21041,13 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg // read the kv pairs { - ctx->kv = LM_GGML_MALLOC(ctx->header.n_kv * sizeof(struct lm_gguf_kv)); + const uint64_t n_kv = ctx->header.n_kv; - for (uint64_t i = 0; i < ctx->header.n_kv; ++i) { + // header.n_kv will hold the actual value of pairs that were successfully read in the loop below + ctx->header.n_kv = 0; + ctx->kv = LM_GGML_CALLOC(n_kv, sizeof(struct lm_gguf_kv)); + + for (uint64_t i = 0; i < n_kv; ++i) { struct lm_gguf_kv * kv = &ctx->kv[i]; //fprintf(stderr, "%s: reading kv %d\n", __func__, i); @@ -20759,7 +21096,7 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg return NULL; } - kv->value.arr.data = LM_GGML_MALLOC(kv->value.arr.n * lm_gguf_type_size(kv->value.arr.type)); + kv->value.arr.data = LM_GGML_CALLOC(kv->value.arr.n, lm_gguf_type_size(kv->value.arr.type)); ok = ok && lm_gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * lm_gguf_type_size(kv->value.arr.type), &offset); } break; @@ -20773,7 +21110,7 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg return NULL; } - kv->value.arr.data = LM_GGML_MALLOC(kv->value.arr.n * sizeof(struct lm_gguf_str)); + kv->value.arr.data = LM_GGML_CALLOC(kv->value.arr.n, sizeof(struct lm_gguf_str)); for (uint64_t j = 0; j < kv->value.arr.n; ++j) { ok = ok && lm_gguf_fread_str(file, &((struct lm_gguf_str *) kv->value.arr.data)[j], &offset); @@ -20789,6 +21126,8 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg if (!ok) { break; } + + ctx->header.n_kv++; } if (!ok) { @@ -20801,7 +21140,7 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg // read the tensor infos { - ctx->infos = LM_GGML_MALLOC(ctx->header.n_tensors * sizeof(struct lm_gguf_tensor_info)); + ctx->infos = LM_GGML_CALLOC(ctx->header.n_tensors, sizeof(struct lm_gguf_tensor_info)); for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { struct lm_gguf_tensor_info * info = &ctx->infos[i]; @@ -20822,8 +21161,17 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg ok = ok && lm_gguf_fread_el (file, &info->type, sizeof(info->type), &offset); ok = ok && lm_gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); + // TODO: return an error instead of crashing with LM_GGML_ASSERT lm_gguf_tensor_info_sanitize(info); + // make sure there is no duplicated tensor names + for (uint64_t j = 0; j < i; ++j) { + if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) { + fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data); + ok = 
false; + } + } + if (!ok) { fprintf(stderr, "%s: failed to read tensor info\n", __func__); fclose(file); @@ -20992,7 +21340,7 @@ void lm_gguf_free(struct lm_gguf_context * ctx) { LM_GGML_FREE(ctx->infos); } - LM_GGML_ALIGNED_FREE(ctx); + LM_GGML_FREE(ctx); } const char * lm_gguf_type_name(enum lm_gguf_type type) { @@ -21303,7 +21651,7 @@ void lm_gguf_set_arr_data(struct lm_gguf_context * ctx, const char * key, enum l ctx->kv[idx].type = LM_GGUF_TYPE_ARRAY; ctx->kv[idx].value.arr.type = type; ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = LM_GGML_MALLOC(n*lm_gguf_type_size(type)); + ctx->kv[idx].value.arr.data = LM_GGML_CALLOC(n, lm_gguf_type_size(type)); memcpy(ctx->kv[idx].value.arr.data, data, n*lm_gguf_type_size(type)); } @@ -21313,7 +21661,7 @@ void lm_gguf_set_arr_str(struct lm_gguf_context * ctx, const char * key, const c ctx->kv[idx].type = LM_GGUF_TYPE_ARRAY; ctx->kv[idx].value.arr.type = LM_GGUF_TYPE_STRING; ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = LM_GGML_MALLOC(n*sizeof(struct lm_gguf_str)); + ctx->kv[idx].value.arr.data = LM_GGML_CALLOC(n, sizeof(struct lm_gguf_str)); for (int i = 0; i < n; i++) { struct lm_gguf_str * str = &((struct lm_gguf_str *)ctx->kv[idx].value.arr.data)[i]; str->n = strlen(data[i]); @@ -21340,7 +21688,7 @@ void lm_gguf_set_kv(struct lm_gguf_context * ctx, struct lm_gguf_context * src) case LM_GGUF_TYPE_ARRAY: { if (src->kv[i].value.arr.type == LM_GGUF_TYPE_STRING) { - const char ** data = LM_GGML_MALLOC(src->kv[i].value.arr.n*sizeof(char *)); + const char ** data = LM_GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *)); for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) { data[j] = ((struct lm_gguf_str *)src->kv[i].value.arr.data)[j].data; } @@ -21360,6 +21708,10 @@ void lm_gguf_set_kv(struct lm_gguf_context * ctx, struct lm_gguf_context * src) void lm_gguf_add_tensor( struct lm_gguf_context * ctx, const struct lm_ggml_tensor * tensor) { + if (lm_gguf_find_tensor(ctx, tensor->name) != -1) { + LM_GGML_ASSERT(false && "duplicated tensor name"); + } + const int idx = ctx->header.n_tensors; ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct lm_gguf_tensor_info)); @@ -21428,7 +21780,7 @@ struct lm_gguf_buf { static struct lm_gguf_buf lm_gguf_buf_init(size_t size) { struct lm_gguf_buf buf = { - /*buf.data =*/ size == 0 ? NULL : LM_GGML_MALLOC(size), + /*buf.data =*/ size == 0 ? NULL : LM_GGML_CALLOC(1, size), /*buf.size =*/ size, /*buf.offset =*/ 0, }; diff --git a/cpp/ggml.h b/cpp/ggml.h index 4a8dafc..6a348d2 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -475,6 +475,7 @@ extern "C" { LM_GGML_OP_LEAKY_RELU, LM_GGML_OP_FLASH_ATTN, + LM_GGML_OP_FLASH_ATTN_EXT, LM_GGML_OP_FLASH_FF, LM_GGML_OP_FLASH_ATTN_BACK, LM_GGML_OP_SSM_CONV, @@ -762,6 +763,8 @@ extern "C" { // use this to compute the memory overhead of a tensor LM_GGML_API size_t lm_ggml_tensor_overhead(void); + LM_GGML_API bool lm_ggml_validate_row_data(enum lm_ggml_type type, const void * data, size_t nbytes); + // main LM_GGML_API struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params); @@ -1720,6 +1723,25 @@ extern "C" { struct lm_ggml_tensor * v, bool masked); +#define LM_GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = LM_GGML_PAD(n_batch, LM_GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! 
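For readers who want the semantics of the new operator spelled out, here is a compact CPU-only sketch of the single-pass (online) softmax accumulation that flash attention is built around, the same scheme as the reference cited in the compute path. It is an illustration under simplifying assumptions (one query row, F32 everywhere, made-up helper names), not the library implementation.

#include <cmath>
#include <cstdio>
#include <vector>

// One query row: scores are folded into a running max M, normalizer S and
// accumulator acc in a single pass over the keys, so the full score row is
// never materialized.
static std::vector<float> attend_one_row(
        const std::vector<float>              & q,    // [D]
        const std::vector<std::vector<float>> & K,    // [n_kv][D]
        const std::vector<std::vector<float>> & V,    // [n_kv][D]
        const std::vector<float>              & mask, // [n_kv], 0 or -INFINITY
        float scale) {
    const size_t D = q.size();

    float M = -INFINITY;             // running maximum of the scores
    float S = 0.0f;                  // running softmax normalizer
    std::vector<float> acc(D, 0.0f); // running weighted sum of value rows

    for (size_t ic = 0; ic < K.size(); ++ic) {
        if (mask[ic] == -INFINITY) {
            continue; // fully masked key, nothing to accumulate
        }

        float s = 0.0f;
        for (size_t d = 0; d < D; ++d) {
            s += q[d]*K[ic][d];
        }
        s = s*scale + mask[ic];

        float ms = 1.0f; // rescale factor for what was accumulated so far
        float vs = 1.0f; // weight of the current value row
        if (s > M) {
            ms = expf(M - s);
            M  = s;
            for (size_t d = 0; d < D; ++d) {
                acc[d] *= ms;
            }
        } else {
            vs = expf(s - M);
        }
        for (size_t d = 0; d < D; ++d) {
            acc[d] += V[ic][d]*vs;
        }
        S = S*ms + vs;
    }

    for (size_t d = 0; d < D; ++d) {
        acc[d] /= S; // same result as softmax(scale*q.K + mask) * V
    }
    return acc;
}

int main() {
    const std::vector<float> q = {1.0f, 0.0f};
    const std::vector<std::vector<float>> K = {{1.0f, 0.0f}, {0.0f, 1.0f}};
    const std::vector<std::vector<float>> V = {{1.0f, 2.0f}, {3.0f, 4.0f}};
    const std::vector<float> mask = {0.0f, 0.0f};

    const auto out = attend_one_row(q, K, V, mask, 1.0f);
    printf("%f %f\n", out[0], out[1]); // ~1.54 2.54 for this toy input
    return 0;
}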
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_flash_attn_ext( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * q, + struct lm_ggml_tensor * k, + struct lm_ggml_tensor * v, + struct lm_ggml_tensor * mask, + float scale); + + LM_GGML_API void lm_ggml_flash_attn_ext_set_prec( + struct lm_ggml_tensor * a, + enum lm_ggml_prec prec); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_flash_attn_back( struct lm_ggml_context * ctx, struct lm_ggml_tensor * q, diff --git a/cpp/llama.cpp b/cpp/llama.cpp index 64406f6..b5668e2 100644 --- a/cpp/llama.cpp +++ b/cpp/llama.cpp @@ -75,6 +75,7 @@ #include #include #include +#include #include #include #include @@ -107,7 +108,6 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 60 - // // logging // @@ -327,6 +327,7 @@ enum llm_kv { LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_PRE, LLM_KV_TOKENIZER_LIST, LLM_KV_TOKENIZER_TOKEN_TYPE, LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, @@ -403,6 +404,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, @@ -1854,7 +1856,7 @@ struct llama_hparams { float f_logit_scale = 0.0f; bool causal_attn = true; - bool need_kq_pos = false; + bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -1944,6 +1946,7 @@ struct llama_cparams { bool embeddings; bool causal_attn; bool offload_kqv; + bool flash_attn; enum llama_pooling_type pooling_type; @@ -2047,8 +2050,8 @@ struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; bool do_copy = false; - // with recurrent state models, a cell can hold the state for more than one past token - bool recurrent = false; + bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token + bool v_trans = true; // the value tensor is transposed // Note: The value of head isn't only used to optimize searching // for a free KV slot. 
llama_decode_internal also uses it, so it @@ -2125,7 +2128,8 @@ struct llama_vocab { ttype type; }; - enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; std::unordered_map token_to_id; std::vector id_to_token; @@ -2346,11 +2350,14 @@ struct llama_context { static bool llama_kv_cache_init( struct llama_kv_cache & cache, - const llama_model & model, + const llama_context * ctx, lm_ggml_type type_k, lm_ggml_type type_v, uint32_t kv_size, bool offload) { + const llama_model & model = ctx->model; + const llama_cparams & cparams = ctx->cparams; + const struct llama_hparams & hparams = model.hparams; const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); @@ -2361,8 +2368,9 @@ static bool llama_kv_cache_init( // TODO: find a nicer way to add other recurrent model architectures cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.v_trans = !cparams.flash_attn; - // TODO: support mixed reccurent Transformer architectues + // TODO: support mixed recurrent Transformer architectures // NOTE: (!a || b) is a logical implication (a -> b) LM_GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s()); LM_GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s()); @@ -2573,6 +2581,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) { } cache.head = 0; cache.used = 0; + + for (auto & buf : cache.bufs) { + lm_ggml_backend_buffer_clear(buf, 0); + } } static bool llama_kv_cache_seq_rm( @@ -2893,6 +2905,7 @@ namespace GGUFMeta { case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool"; case LLAMA_KV_OVERRIDE_TYPE_INT: return "int"; case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float"; + case LLAMA_KV_OVERRIDE_TYPE_STR: return "str"; } return "unknown"; } @@ -2904,13 +2917,16 @@ namespace GGUFMeta { __func__, override_type_to_str(ovrd->tag), ovrd->key); switch (ovrd->tag) { case LLAMA_KV_OVERRIDE_TYPE_BOOL: { - LLAMA_LOG_INFO("%s\n", ovrd->bool_value ? "true" : "false"); + LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false"); } break; case LLAMA_KV_OVERRIDE_TYPE_INT: { - LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->int_value); + LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64); } break; case LLAMA_KV_OVERRIDE_TYPE_FLOAT: { - LLAMA_LOG_INFO("%.6f\n", ovrd->float_value); + LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_STR: { + LLAMA_LOG_INFO("%s\n", ovrd->val_str); } break; default: // Shouldn't be possible to end up here, but just in case... 
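The renamed override members (val_i64, val_f64, val_bool, val_str) amount to a tagged value that is dispatched on in the logging and try_override paths. The struct below is a simplified stand-in, not the real llama_model_kv_override from llama.h, and the key/value used in main are only examples; it is here to make the print-by-tag dispatch concrete, including the new string case.

#include <cstdio>
#include <cstdint>
#include <cstring>

enum kv_override_type { KV_INT, KV_FLOAT, KV_BOOL, KV_STR };

struct kv_override {
    kv_override_type tag;
    char key[128];
    union {
        int64_t val_i64;
        double  val_f64;
        bool    val_bool;
        char    val_str[128];
    };
};

static void print_override(const kv_override & o) {
    switch (o.tag) {
        case KV_INT:   printf("%s = %lld\n", o.key, (long long) o.val_i64);          break;
        case KV_FLOAT: printf("%s = %.6f\n", o.key, o.val_f64);                      break;
        case KV_BOOL:  printf("%s = %s\n",   o.key, o.val_bool ? "true" : "false");  break;
        case KV_STR:   printf("%s = %s\n",   o.key, o.val_str);                      break;
    }
}

int main() {
    kv_override o = {};
    o.tag = KV_STR;
    std::strncpy(o.key,     "tokenizer.ggml.pre", sizeof(o.key) - 1);     // example key
    std::strncpy(o.val_str, "default",            sizeof(o.val_str) - 1); // example value
    print_override(o);
    return 0;
}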
@@ -2929,7 +2945,7 @@ namespace GGUFMeta { static typename std::enable_if::value, bool>::type try_override(OT & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) { - target = ovrd->bool_value; + target = ovrd->val_bool; return true; } return false; @@ -2939,7 +2955,7 @@ namespace GGUFMeta { static typename std::enable_if::value && std::is_integral::value, bool>::type try_override(OT & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) { - target = ovrd->int_value; + target = ovrd->val_i64; return true; } return false; @@ -2949,7 +2965,7 @@ namespace GGUFMeta { static typename std::enable_if::value, bool>::type try_override(T & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) { - target = ovrd->float_value; + target = ovrd->val_f64; return true; } return false; @@ -2958,12 +2974,11 @@ namespace GGUFMeta { template static typename std::enable_if::value, bool>::type try_override(T & target, const struct llama_model_kv_override * ovrd) { - (void)target; - (void)ovrd; - if (!ovrd) { return false; } - // Currently, we should never end up here so it would be a bug if we do. - throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n", - ovrd ? ovrd->key : "NULL")); + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) { + target = ovrd->val_str; + return true; + } + return false; } static bool set(const lm_gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) { @@ -2996,6 +3011,7 @@ struct llama_model_loader { size_t n_bytes = 0; bool use_mmap = false; + bool check_tensors; llama_files files; llama_ftype ftype; @@ -3010,9 +3026,13 @@ struct llama_model_loader { lm_ggml_tensor * tensor; - llama_tensor_weight(uint16_t idx, const char * name, const struct lm_gguf_context * lm_gguf_ctx, lm_ggml_tensor * tensor) : idx(idx), tensor(tensor) { + llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct lm_gguf_context * lm_gguf_ctx, lm_ggml_tensor * tensor) : idx(idx), tensor(tensor) { const int tensor_idx = lm_gguf_find_tensor(lm_gguf_ctx, name); offs = lm_gguf_get_data_offset(lm_gguf_ctx) + lm_gguf_get_tensor_offset(lm_gguf_ctx, tensor_idx); + + if (offs + lm_ggml_nbytes(tensor) < offs || offs + lm_ggml_nbytes(tensor) > file->size) { + throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name)); + } } }; std::vector weights; @@ -3025,7 +3045,7 @@ struct llama_model_loader { std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); - llama_model_loader(const std::string & fname, bool use_mmap, const struct llama_model_kv_override * param_overrides_p) { + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -3051,15 +3071,15 @@ struct llama_model_loader { get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + files.emplace_back(new llama_file(fname.c_str(), "rb")); + contexts.emplace_back(ctx); + // Save tensors data offset of the main file. // For subsidiary files, `meta` tensor data offset must not be used, // so we build a unified tensors index for weights. 
for (lm_ggml_tensor * cur = lm_ggml_get_first_tensor(ctx); cur; cur = lm_ggml_get_next_tensor(ctx, cur)) { - weights.emplace_back(0, cur->name, meta, cur); + weights.emplace_back(files.back().get(), 0, cur->name, meta, cur); } - files.emplace_back(new llama_file(fname.c_str(), "rb")); - contexts.emplace_back(ctx); - uint16_t n_split = 0; get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); @@ -3093,12 +3113,13 @@ struct llama_model_loader { throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path)); } + files.emplace_back(new llama_file(split_path, "rb")); + contexts.emplace_back(ctx); + // Save tensors data offset info of the shard. for (lm_ggml_tensor * cur = lm_ggml_get_first_tensor(ctx); cur; cur = lm_ggml_get_next_tensor(ctx, cur)) { - weights.emplace_back(idx, cur->name, ctx_gguf, cur); + weights.emplace_back(files.back().get(), idx, cur->name, ctx_gguf, cur); } - files.emplace_back(new llama_file(split_path, "rb")); - contexts.emplace_back(ctx); lm_gguf_free(ctx_gguf); } @@ -3121,9 +3142,17 @@ struct llama_model_loader { fver = (enum llama_fver) lm_gguf_get_version(meta); + std::set tensor_names; for (auto & w : weights) { n_elements += lm_ggml_nelements(w.tensor); n_bytes += lm_ggml_nbytes(w.tensor); + // make sure there is no duplicated tensor names + const std::string name(w.tensor->name); + auto found = tensor_names.find(name); + if (found != tensor_names.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", w.tensor->name)); + } + tensor_names.insert(name); } LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", @@ -3229,6 +3258,7 @@ struct llama_model_loader { } this->use_mmap = use_mmap; + this->check_tensors = check_tensors; } ~llama_model_loader() { @@ -3308,6 +3338,10 @@ struct llama_model_loader { return nullptr; } + const llama_tensor_weight * get_weight(int i) const { + return get_weight(get_tensor_name(i)); + } + const llama_tensor_weight & require_weight(const char * name) const { const llama_tensor_weight * weight = get_weight(name); if (!weight) { @@ -3483,6 +3517,10 @@ struct llama_model_loader { file->seek(w.offs, SEEK_SET); file->read_raw(cur->data, lm_ggml_nbytes(cur)); } + + if (check_tensors && !lm_ggml_validate_row_data(cur->type, cur->data, lm_ggml_nbytes(cur))) { + throw std::runtime_error(format("tensor '%s' has invalid data", lm_ggml_get_name(cur))); + } } size_t size_done = 0; @@ -3499,6 +3537,8 @@ struct llama_model_loader { LM_GGML_ASSERT(size_data != 0 && "call init_mappings() first"); std::vector> read_buf; + std::vector>> validation_result; + for (struct lm_ggml_tensor * cur = lm_ggml_get_first_tensor(ctx); cur != NULL; cur = lm_ggml_get_next_tensor(ctx, cur)) { const auto * weight = get_weight(lm_ggml_get_name(cur)); if (weight == nullptr) { @@ -3520,37 +3560,66 @@ struct llama_model_loader { if (bufs_mmap.count(weight->idx)) { buf_mmap = bufs_mmap.at(weight->idx); } + uint8_t * data = (uint8_t *) mapping->addr + weight->offs; + + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] { + return std::make_pair(cur, lm_ggml_validate_row_data(cur->type, data, n_size)); + })); + } + LM_GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated if (buf_mmap && cur->data == nullptr) { - lm_ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + weight->offs); + lm_ggml_backend_tensor_alloc(buf_mmap, cur, data); if 
(lmlocks) { const auto & lmlock = lmlocks->at(weight->idx); - lmlock->grow_to(weight->offs + lm_ggml_nbytes(cur)); + lmlock->grow_to(weight->offs + n_size); } auto & mmap_used = mmaps_used[weight->idx]; mmap_used.first = std::min(mmap_used.first, weight->offs); mmap_used.second = std::max(mmap_used.second, weight->offs + n_size); } else { - lm_ggml_backend_tensor_set(cur, (uint8_t *) mapping->addr + weight->offs, 0, n_size); + lm_ggml_backend_tensor_set(cur, data, 0, n_size); } } else { LM_GGML_ASSERT(weight->idx < files.size()); const auto & file = files.at(weight->idx); if (lm_ggml_backend_buffer_is_host(cur->buffer)) { file->seek(weight->offs, SEEK_SET); - file->read_raw(cur->data, lm_ggml_nbytes(cur)); + file->read_raw(cur->data, n_size); + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] { + return std::make_pair(cur, lm_ggml_validate_row_data(cur->type, cur->data, n_size)); + })); + } } else { - read_buf.resize(lm_ggml_nbytes(cur)); + read_buf.resize(n_size); file->seek(weight->offs, SEEK_SET); - file->read_raw(read_buf.data(), lm_ggml_nbytes(cur)); + file->read_raw(read_buf.data(), n_size); lm_ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !lm_ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", lm_ggml_get_name(cur))); + } } } size_done += n_size; } + // check validation results + bool validation_failed = false; + for (auto & future : validation_result) { + auto result = future.get(); + if (!result.second) { + LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, lm_ggml_get_name(result.first)); + validation_failed = true; + } + } + if (validation_failed) { + throw std::runtime_error("found tensors with invalid data"); + } + // check if this is the last call and do final cleanup if (size_done >= size_data) { // unmap offloaded tensors and metadata @@ -4144,7 +4213,7 @@ static void llm_load_hparams( model.ftype = ml.ftype; if (hparams.f_max_alibi_bias > 0.0f) { - hparams.need_kq_pos = true; + hparams.use_alibi = true; } hparams.rope_type = llama_rope_type(&model); @@ -4167,11 +4236,13 @@ static void llm_load_vocab( // determine vocab type { - std::string tokenizer_name; + std::string tokenizer_model; + std::string tokenizer_pre; - ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name); + ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); + ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); - if (tokenizer_name == "no_vocab") { + if (tokenizer_model == "no_vocab") { vocab.type = LLAMA_VOCAB_TYPE_NONE; // default special tokens @@ -4185,7 +4256,7 @@ static void llm_load_vocab( vocab.linefeed_id = -1; return; - } else if (tokenizer_name == "llama") { + } else if (tokenizer_model == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; // default special tokens @@ -4230,9 +4301,27 @@ static void llm_load_vocab( if (add_space_prefix_keyidx != -1) { vocab.add_space_prefix = lm_gguf_get_val_bool(ctx, add_space_prefix_keyidx); } // The default value of add_space_prefix is true. 
- } else if (tokenizer_name == "gpt2") { - vocab.type = LLAMA_VOCAB_TYPE_BPE; + } else if (tokenizer_model == "bert") { + vocab.type = LLAMA_VOCAB_TYPE_WPM; + // default special tokens + vocab.special_bos_id = -1; + vocab.special_eos_id = -1; + vocab.special_unk_id = 100; + vocab.special_sep_id = 102; + vocab.special_pad_id = 0; + vocab.special_cls_id = 101; + vocab.special_mask_id = 103; + vocab.add_space_prefix = false; + } else { + if (tokenizer_model == "gpt2") { + vocab.type = LLAMA_VOCAB_TYPE_BPE; + } else { + LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str()); + LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); + vocab.type = LLAMA_VOCAB_TYPE_SPM; + return; + } // read bpe merges and populate bpe ranks const int merges_keyidx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); if (merges_keyidx == -1) { @@ -4266,23 +4355,50 @@ static void llm_load_vocab( vocab.special_pad_id = -1; vocab.special_cls_id = -1; vocab.special_mask_id = -1; - } else if (tokenizer_name == "bert") { - vocab.type = LLAMA_VOCAB_TYPE_WPM; + } - // default special tokens - vocab.special_bos_id = -1; - vocab.special_eos_id = -1; - vocab.special_unk_id = 100; - vocab.special_sep_id = 102; - vocab.special_pad_id = 0; - vocab.special_cls_id = 101; - vocab.special_mask_id = 103; - vocab.add_space_prefix = false; + // for now, only BPE models have pre-tokenizers + if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { + if (tokenizer_pre.empty()) { + LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__); + LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if ( + tokenizer_pre == "default") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if ( + tokenizer_pre == "llama3" || + tokenizer_pre == "llama-v3" || + tokenizer_pre == "llama-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; + } else if ( + tokenizer_pre == "deepseek-llm") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; + } else if ( + tokenizer_pre == "deepseek-coder") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; + } else if ( + tokenizer_pre == "falcon") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; + } else if ( + tokenizer_pre == "mpt") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT; + } else if ( + tokenizer_pre == "starcoder") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER; + } else if ( + tokenizer_pre == "gpt-2") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else { + throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); + } } else { - LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); - LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); - - vocab.type = LLAMA_VOCAB_TYPE_SPM; + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } } @@ -5977,7 +6093,7 @@ static bool llm_load_tensors( // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { try { - llama_model_loader ml(fname, params.use_mmap, params.kv_overrides); + llama_model_loader ml(fname, 
params.use_mmap, params.check_tensors, params.kv_overrides); model.hparams.vocab_only = params.vocab_only; @@ -6106,37 +6222,47 @@ static struct lm_ggml_tensor * llm_build_inp_embd( static void llm_build_kv_store( struct lm_ggml_context * ctx, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct lm_ggml_cgraph * graph, struct lm_ggml_tensor * k_cur, struct lm_ggml_tensor * v_cur, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, const llm_build_cb & cb, int64_t il) { + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); LM_GGML_ASSERT(kv.size == n_ctx); - // compute the transposed [n_tokens, n_embd] V matrix - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct lm_ggml_tensor * v_cur_t = lm_ggml_transpose(ctx, v_cur); - cb(v_cur_t, "v_cur_t", il); - struct lm_ggml_tensor * k_cache_view = lm_ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (lm_ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); - struct lm_ggml_tensor * v_cache_view = lm_ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*lm_ggml_element_size(kv.v_l[il]), - (kv_head)*lm_ggml_element_size(kv.v_l[il])); + // note: storing RoPE-ed version of K in the KV cache + lm_ggml_build_forward_expand(graph, lm_ggml_cpy(ctx, k_cur, k_cache_view)); + + assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); + + struct lm_ggml_tensor * v_cache_view = nullptr; + + if (cparams.flash_attn) { + v_cache_view = lm_ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (kv_head)*lm_ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)); + } else { + // note: the V cache is transposed when not using flash attention + v_cache_view = lm_ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, + ( n_ctx)*lm_ggml_element_size(kv.v_l[il]), + (kv_head)*lm_ggml_element_size(kv.v_l[il])); + + v_cur = lm_ggml_transpose(ctx, v_cur); + } cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! 
- lm_ggml_build_forward_expand(graph, lm_ggml_cpy(ctx, k_cur, k_cache_view)); - lm_ggml_build_forward_expand(graph, lm_ggml_cpy(ctx, v_cur_t, v_cache_view)); + lm_ggml_build_forward_expand(graph, lm_ggml_cpy(ctx, v_cur, v_cache_view)); } static struct lm_ggml_tensor * llm_build_norm( @@ -6356,11 +6482,11 @@ static struct lm_ggml_tensor * llm_build_moe_ffn( return moe_out; } -// if max_alibi_bias > 0 then apply ALiBi static struct lm_ggml_tensor * llm_build_kqv( struct lm_ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct lm_ggml_cgraph * graph, struct lm_ggml_tensor * wo, @@ -6368,12 +6494,12 @@ static struct lm_ggml_tensor * llm_build_kqv( struct lm_ggml_tensor * q_cur, struct lm_ggml_tensor * kq_mask, struct lm_ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t n_kv, float kq_scale, const llm_build_cb & cb, int il) { + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -6391,71 +6517,99 @@ static struct lm_ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + struct lm_ggml_tensor * cur; - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32); - } + if (cparams.flash_attn) { + LM_GGML_UNUSED(model); + LM_GGML_UNUSED(n_ctx); - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and then : - // kq = 30 * tanh(kq / 30) - // before the softmax below + // note: if this assert triggers, then some check has failed earlier + // the idea is to detect during context creation that ALiBi would be used and disable Flash Attention + LM_GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); - //try from phi2 - //lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32); + // split cached v into n_head heads (not transposed) + struct lm_ggml_tensor * v = + lm_ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + lm_ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + lm_ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); - kq = lm_ggml_tanh(ctx, lm_ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - kq = lm_ggml_scale(ctx, kq, 30); - } + cur = lm_ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + lm_ggml_flash_attn_ext_set_prec(cur, LM_GGML_PREC_F32); + } + + cur = lm_ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); + } else { + struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32); + } + + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below + + //try from phi2 + 
//lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32); + + kq = lm_ggml_tanh(ctx, lm_ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + kq = lm_ggml_scale(ctx, kq, 30); + } #if defined(LM_GGML_USE_KOMPUTE) #pragma message("TODO: ALiBi support in lm_ggml_soft_max_ext is not implemented for Kompute") #pragma message(" Falling back to lm_ggml_alibi(). Will become an error in Mar 2024") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.f_max_alibi_bias > 0.0f) { - kq = lm_ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + if (hparams.use_alibi) { + kq = lm_ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); - kq = lm_ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); + kq = lm_ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); - kq = lm_ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = lm_ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = lm_ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else + kq = lm_ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else #endif - { - kq = lm_ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - } + { + kq = lm_ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + } - LM_GGML_ASSERT(kv.size == n_ctx); + LM_GGML_ASSERT(kv.size == n_ctx); - // split cached v into n_head heads - struct lm_ggml_tensor * v = - lm_ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - lm_ggml_element_size(kv.v_l[il])*n_ctx, - lm_ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); + // split cached v into n_head heads + struct lm_ggml_tensor * v = + lm_ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + lm_ggml_element_size(kv.v_l[il])*n_ctx, + lm_ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - struct lm_ggml_tensor * kqv = lm_ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + struct lm_ggml_tensor * kqv = lm_ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); - struct lm_ggml_tensor * kqv_merged = lm_ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + struct lm_ggml_tensor * kqv_merged = lm_ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); - struct lm_ggml_tensor * cur = lm_ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); + cur = lm_ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } lm_ggml_build_forward_expand(graph, cur); @@ -6475,6 +6629,7 @@ static struct lm_ggml_tensor * llm_build_kv( struct lm_ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct lm_ggml_cgraph * graph, struct lm_ggml_tensor * wo, @@ -6484,7 +6639,6 @@ static struct lm_ggml_tensor * llm_build_kv( struct lm_ggml_tensor * q_cur, struct lm_ggml_tensor * kq_mask, struct lm_ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, int32_t n_kv, @@ -6498,12 +6652,12 @@ static struct lm_ggml_tensor * llm_build_kv( lm_ggml_build_forward_expand(graph, k_cur); lm_ggml_build_forward_expand(graph, v_cur); - llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); + llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il); 
struct lm_ggml_tensor * cur; - cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, - q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il); + cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, + q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -6545,6 +6699,8 @@ struct llm_build_context { const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_orig_ctx; + const bool flash_attn; + const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -6591,6 +6747,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), + flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -6705,15 +6862,31 @@ struct llm_build_context { lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); - lm_ggml_tensor * view_v_src = lm_ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - lm_ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - lm_ggml_row_size(kv_self.v_l[il]->type, i)); + lm_ggml_tensor * view_v_src; + lm_ggml_tensor * view_v_dst; - lm_ggml_tensor * view_v_dst = lm_ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - lm_ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - lm_ggml_row_size(kv_self.v_l[il]->type, id)); + if (flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = lm_ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + lm_ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + lm_ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); + + view_v_dst = lm_ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + lm_ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + lm_ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = lm_ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + lm_ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + lm_ggml_row_size(kv_self.v_l[il]->type, i)); + + view_v_dst = lm_ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + lm_ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + lm_ggml_row_size(kv_self.v_l[il]->type, id)); + } lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_k_src, view_k_dst)); lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_v_src, view_v_dst)); @@ -6743,20 +6916,26 @@ struct llm_build_context { struct lm_ggml_tensor * build_inp_KQ_mask(bool causal = true) { if (causal) { - lctx.inp_KQ_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, n_tokens); + lctx.inp_KQ_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD)); } else { - lctx.inp_KQ_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_tokens, n_tokens); + lctx.inp_KQ_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_tokens, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD)); } cb(lctx.inp_KQ_mask, "KQ_mask", -1); lm_ggml_set_input(lctx.inp_KQ_mask); - return lctx.inp_KQ_mask; + return flash_attn ? 
lm_ggml_cast(ctx0, lctx.inp_KQ_mask, LM_GGML_TYPE_F16) : lctx.inp_KQ_mask; } - struct lm_ggml_tensor * build_inp_KQ_pos() { - lctx.inp_KQ_pos = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_F32, n_kv); + struct lm_ggml_tensor * build_inp_KQ_pos(bool causal = true) { + if (causal) { + lctx.inp_KQ_pos = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_F32, n_kv); + } else { + // TODO: this will be needed for ALiBi-based BERT models + // https://github.com/ggerganov/llama.cpp/pull/6826 + lctx.inp_KQ_pos = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_F32, n_tokens); + } cb(lctx.inp_KQ_pos, "KQ_pos", -1); lm_ggml_set_input(lctx.inp_KQ_pos); - return lctx.inp_KQ_pos; + return flash_attn ? lm_ggml_cast(ctx0, lctx.inp_KQ_pos, LM_GGML_TYPE_F16) : lctx.inp_KQ_pos; } struct lm_ggml_tensor * build_inp_mean() { @@ -6862,9 +7041,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7002,9 +7181,9 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7109,9 +7288,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7229,9 +7408,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7354,9 +7533,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -7506,9 +7685,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, nullptr, 
n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7618,9 +7797,9 @@ struct llm_build_context { Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7822,9 +8001,9 @@ struct llm_build_context { ); cb(Vcur, "Vcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7918,9 +8097,9 @@ struct llm_build_context { Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8211,9 +8390,9 @@ struct llm_build_context { Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8342,14 +8521,15 @@ struct llm_build_context { Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } } @@ -8491,9 +8671,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + 
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8609,9 +8789,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8722,9 +8902,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8836,9 +9016,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8991,9 +9171,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9108,9 +9288,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9221,9 +9401,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } struct lm_ggml_tensor * sa_out = cur; @@ -9324,9 +9504,9 @@ struct llm_build_context { Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9431,9 +9611,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", 
il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9547,9 +9727,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9664,9 +9844,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9794,9 +9974,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9915,9 +10095,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -10034,9 +10214,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10324,9 +10504,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10455,9 +10635,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, nullptr, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 
1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10884,7 +11064,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (hparams.need_kq_pos) { + // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch + // this allows to process multiple sequences in parallel with ALiBi-based models + if (hparams.use_alibi) { const int64_t n_kv = kv_self.n; LM_GGML_ASSERT(lctx.inp_KQ_pos); @@ -11266,7 +11448,7 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(kv_self.size, std::max(32u, LM_GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + kv_self.n = std::min(kv_self.size, std::max(256u, LM_GGML_PAD(llama_kv_cache_cell_max(kv_self), 256))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } } @@ -11434,6 +11616,10 @@ static int llama_decode_internal( } } + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + lm_ggml_backend_sched_reset(lctx.sched); + return 0; } @@ -11459,7 +11645,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // each move requires 6*n_layer tensors (see build_defrag) // - source view, destination view, copy operation // - x2 for keys and values - const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer); // determine which KV cells to move where // @@ -11783,7 +11971,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { } case LLAMA_VOCAB_TYPE_BPE: { LM_GGML_ASSERT(false); - return unicode_utf8_to_byte(token_data.text); + return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after LM_GGML_ASSERT? 
} case LLAMA_VOCAB_TYPE_WPM: { LM_GGML_ASSERT(false); @@ -12005,7 +12193,79 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - auto word_collection = bpe_gpt2_preprocess(text); + + std::vector word_collection; + switch (vocab.type) { + case LLAMA_VOCAB_TYPE_BPE: + switch (vocab.type_pre) { + case LLAMA_VOCAB_PRE_TYPE_LLAMA3: + word_collection = unicode_regex_split(text, { + // original regex from tokenizer.json + //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + + // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM: + word_collection = unicode_regex_split(text, { + "[\r\n]", + "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", + "\\s?[!-/:-~!-/:-~‘-‟ -。]+", + "\\s+$", + "[一-龥ࠀ-一가-퟿]+", + "\\p{N}+", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: + word_collection = unicode_regex_split(text, { + "[\r\n]", + "\\s?\\p{L}+", + "\\s?\\p{P}+", + "[一-龥ࠀ-一가-퟿]+", + "\\p{N}+", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_FALCON: + word_collection = unicode_regex_split(text, { + "[\\p{P}\\$\\+<=>\\^~\\|]+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "\\p{N}+", + "[0-9][0-9][0-9]", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_MPT: + // TODO: MPT pre-tokenization regexes are unknown + // the following are close, but not exact. 
run the following: + // ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf + LM_GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed"); + word_collection = unicode_regex_split(text, { + "\\s?\\p{L}+", + "\\s?\\p{P}+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_STARCODER: + case LLAMA_VOCAB_PRE_TYPE_GPT2: + word_collection = unicode_regex_split(text, { + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }); + break; + default: + // default regex for BPE tokenization pre-processing + word_collection = unicode_regex_split(text, { + "[\\p{P}\\$\\+<=>\\^~\\|]+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "\\p{N}+", + "[0-9][0-9][0-9]", + }); + break; + } + break; + default: + LM_GGML_ASSERT(false); + break; + } symbols_final.clear(); @@ -12132,145 +12392,6 @@ struct llm_tokenizer_bpe { work_queue.push(bigram); } - std::vector bpe_gpt2_preprocess(const std::string & text) { - std::vector bpe_words; - std::vector bpe_encoded_words; - - std::string token = ""; - // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ - bool collecting_numeric = false; - bool collecting_letter = false; - bool collecting_special = false; - bool collecting_whitespace_lookahead = false; - bool collecting = false; - - std::vector text_utf; - text_utf.reserve(text.size()); - bpe_words.reserve(text.size()); - bpe_encoded_words.reserve(text.size()); - - const auto cpts = unicode_cpts_from_utf8(text); - for (size_t i = 0; i < cpts.size(); ++i) - text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i])); - - for (int i = 0; i < (int)text_utf.size(); i++) { - const std::string & utf_char = text_utf[i]; - bool split_condition = false; - int bytes_remain = text_utf.size() - i; - // forward backward lookups - const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; - const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? 
text_utf[i + 2] : ""; - - // handling contractions - if (!split_condition && bytes_remain >= 2) { - // 's|'t|'m|'d - if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) { - split_condition = true; - } - if (split_condition) { - if (token.size()) { - bpe_words.emplace_back(token); // push previous content as token - } - token = utf_char + utf_char_next; - bpe_words.emplace_back(token); - token = ""; - i++; - continue; - } - } - if (!split_condition && bytes_remain >= 3) { - // 're|'ve|'ll - if (utf_char == "\'" && ( - (utf_char_next == "r" && utf_char_next_next == "e") || - (utf_char_next == "v" && utf_char_next_next == "e") || - (utf_char_next == "l" && utf_char_next_next == "l")) - ) { - split_condition = true; - } - if (split_condition) { - // current token + next token can be defined - if (token.size()) { - bpe_words.emplace_back(token); // push previous content as token - } - token = utf_char + utf_char_next + utf_char_next_next; - bpe_words.emplace_back(token); // the contraction - token = ""; - i += 2; - continue; - } - } - - if (!split_condition && !collecting) { - if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { - collecting_letter = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { - collecting_numeric = true; - collecting = true; - } - else if ( - ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) || - (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) - ) { - collecting_special = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) { - collecting_whitespace_lookahead = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) { - split_condition = true; - } - } - else if (!split_condition && collecting) { - if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) { - split_condition = true; - } - else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) { - split_condition = true; - } - else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { - split_condition = true; - } - else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { - split_condition = true; - } - } - - if (utf_char_next == "") { - split_condition = true; // final - token += utf_char; - } - - if (split_condition) { - if (token.size()) { - bpe_words.emplace_back(token); - } - token = utf_char; - collecting = false; - collecting_letter = false; - collecting_numeric = false; - collecting_special = false; - collecting_whitespace_lookahead = false; - } - else { - token += utf_char; - } - } - - for (std::string & word : bpe_words) { - std::string encoded_token = ""; - for (char & c : 
word) { - encoded_token += unicode_byte_to_utf8(c); - } - bpe_encoded_words.emplace_back(encoded_token); - } - - return bpe_encoded_words; - } - const llama_vocab & vocab; std::vector symbols; @@ -12590,7 +12711,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & } break; case LLAMA_VOCAB_TYPE_BPE: { - if (add_special && vocab.special_add_bos == 1) { + if (add_special && vocab.special_add_bos != 0) { LM_GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); } @@ -14362,14 +14483,20 @@ static lm_ggml_type llama_tensor_get_type(quantize_state_internal & qs, lm_ggml_ } static size_t llama_tensor_quantize_internal(enum lm_ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) { - std::mutex mutex; - int64_t counter = 0; - size_t new_size = 0; if (nthread < 2) { // single-thread - return lm_ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix); + size_t new_size = lm_ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix); + if (!lm_ggml_validate_row_data(new_type, new_data, new_size)) { + throw std::runtime_error("quantized data validation failed"); + } + return new_size; } - auto compute = [&mutex, &counter, &new_size, new_type, f32_data, new_data, chunk_size, + + std::mutex mutex; + int64_t counter = 0; + size_t new_size = 0; + bool valid = true; + auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size, nrows, n_per_row, imatrix]() { const int64_t nrows_per_chunk = chunk_size / n_per_row; size_t local_size = 0; @@ -14384,7 +14511,17 @@ static size_t llama_tensor_quantize_internal(enum lm_ggml_type new_type, const f } lock.unlock(); const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk); - local_size += lm_ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix); + size_t this_size = lm_ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix); + local_size += this_size; + + // validate the quantized data + const size_t row_size = lm_ggml_row_size(new_type, n_per_row); + void * this_data = (char *) new_data + first_row * row_size; + if (!lm_ggml_validate_row_data(new_type, this_data, this_size)) { + std::unique_lock lock(mutex); + valid = false; + break; + } } }; for (int it = 0; it < nthread - 1; ++it) { @@ -14393,6 +14530,9 @@ static size_t llama_tensor_quantize_internal(enum lm_ggml_type new_type, const f compute(); for (auto & w : workers) { w.join(); } workers.clear(); + if (!valid) { + throw std::runtime_error("quantized data validation failed"); + } return new_size; } @@ -14455,7 +14595,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s auto v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } - llama_model_loader ml(fname_inp, use_mmap, kv_overrides); + llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides); ml.init_mappings(false); // no prefetching llama_model model; @@ -14493,11 +14633,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s for (auto & o : overrides) { if (o.key[0] == 0) break; if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { - lm_gguf_set_val_f32(ctx_out, o.key, o.float_value); + lm_gguf_set_val_f32(ctx_out, o.key, o.val_f64); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { - 
lm_gguf_set_val_i32(ctx_out, o.key, o.int_value); + lm_gguf_set_val_i32(ctx_out, o.key, o.val_i64); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { - lm_gguf_set_val_bool(ctx_out, o.key, o.bool_value); + lm_gguf_set_val_bool(ctx_out, o.key, o.val_bool); + } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { + lm_gguf_set_val_str(ctx_out, o.key, o.val_str); } else { LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); } @@ -14539,26 +14681,74 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::vector> work; std::vector> f32_conv_buf; + uint16_t n_split = 1; + // Assume split index is continuous + if (params->keep_split) { + for (int i = 0; i < ml.n_tensors; ++i) { + n_split = std::max(uint16_t(ml.get_weight(i)->idx+1), n_split); + } + } + std::vector ctx_outs(n_split, NULL); + ctx_outs[0] = ctx_out; + // populate the original tensors so we get an initial meta data for (int i = 0; i < ml.n_tensors; ++i) { - const struct lm_ggml_tensor * meta = ml.get_tensor_meta(i); - lm_gguf_add_tensor(ctx_out, meta); + auto weight = ml.get_weight(i); + uint16_t i_split = params->keep_split ? weight->idx : 0; + struct lm_ggml_tensor * tensor = weight->tensor; + if (ctx_outs[i_split] == NULL) { + ctx_outs[i_split] = lm_gguf_init_empty(); + } + lm_gguf_add_tensor(ctx_outs[i_split], tensor); } - std::ofstream fout(fname_out, std::ios::binary); - fout.exceptions(std::ofstream::failbit); // fail fast on write errors - - const size_t meta_size = lm_gguf_get_meta_size(ctx_out); + // Set split info if needed + if (n_split > 1) { + for (size_t i = 0; i < ctx_outs.size(); ++i) { + lm_gguf_set_val_u16(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i); + lm_gguf_set_val_u16(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split); + lm_gguf_set_val_i32(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors); + } + } - LLAMA_LOG_INFO("%s: meta size = %zu bytes\n", __func__, meta_size); + int cur_split = -1; + std::ofstream fout; + auto close_ofstream = [&]() { + // Write metadata and close file handler + if (fout.is_open()) { + fout.seekp(0); + std::vector data(lm_gguf_get_meta_size(ctx_outs[cur_split])); + lm_gguf_get_meta_data(ctx_outs[cur_split], data.data()); + fout.write((const char *) data.data(), data.size()); + fout.close(); + } + }; + auto new_ofstream = [&](int index) { + cur_split = index; + LM_GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized lm_gguf_context"); + std::string fname = fname_out; + if (params->keep_split) { + char split_path[PATH_MAX] = {0}; + llama_split_path(split_path, sizeof(split_path), fname_out.c_str(), cur_split, n_split); + fname = std::string(split_path); + } - // placeholder for the meta data - ::zeros(fout, meta_size); + fout = std::ofstream(fname, std::ios::binary); + fout.exceptions(std::ofstream::failbit); // fail fast on write errors + const size_t meta_size = lm_gguf_get_meta_size(ctx_outs[cur_split]); + // placeholder for the meta data + ::zeros(fout, meta_size); + }; const auto tn = LLM_TN(model.arch); - + new_ofstream(0); for (int i = 0; i < ml.n_tensors; ++i) { - struct lm_ggml_tensor * tensor = ml.get_tensor_meta(i); + auto weight = ml.get_weight(i); + struct lm_ggml_tensor * tensor = weight->tensor; + if (weight->idx != cur_split && params->keep_split) { + close_ofstream(); + new_ofstream(weight->idx); + } const std::string name = lm_ggml_get_name(tensor); @@ -14713,26 +14903,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s total_size_new += 
new_size; // update the gguf meta data as we go - lm_gguf_set_tensor_type(ctx_out, name.c_str(), new_type); - lm_gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size); + lm_gguf_set_tensor_type(ctx_outs[cur_split], name.c_str(), new_type); + lm_gguf_set_tensor_data(ctx_outs[cur_split], name.c_str(), new_data, new_size); // write tensor data + padding fout.write((const char *) new_data, new_size); zeros(fout, LM_GGML_PAD(new_size, align) - new_size); } - - // go back to beginning of file and write the updated meta data - { - fout.seekp(0); - std::vector data(lm_gguf_get_meta_size(ctx_out)); - lm_gguf_get_meta_data(ctx_out, data.data()); - fout.write((const char *) data.data(), data.size()); + close_ofstream(); + for (auto & c:ctx_outs) { + lm_gguf_free(c); } - fout.close(); - - lm_gguf_free(ctx_out); - LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); @@ -14776,7 +14958,7 @@ static int llama_apply_lora_from_file_internal( std::unique_ptr ml; if (path_base_model) { LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); - ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ nullptr)); + ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr)); ml->init_mappings(/*prefetch*/ false); // no prefetching } @@ -15035,6 +15217,7 @@ struct llama_model_params llama_model_default_params() { /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_mlock =*/ false, + /*.check_tensors =*/ false, }; #ifdef LM_GGML_USE_METAL @@ -15071,6 +15254,7 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embeddings =*/ false, /*.offload_kqv =*/ true, + /*.flash_attn =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -15088,6 +15272,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { /*.quantize_output_tensor =*/ true, /*.only_copy =*/ false, /*.pure =*/ false, + /*.keep_split =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, }; @@ -15236,6 +15421,7 @@ struct llama_context * llama_new_context_with_model( cparams.defrag_thold = params.defrag_thold; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; + cparams.flash_attn = params.flash_attn; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -15243,12 +15429,20 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; // this is necessary due to kv_self.n being padded later during inference - cparams.n_ctx = LM_GGML_PAD(cparams.n_ctx, 32); + cparams.n_ctx = LM_GGML_PAD(cparams.n_ctx, 256); // with causal attention, the batch size is limited by the context size cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; - cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + // the batch has to be at least LM_GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. 
lm_ggml_flash_attn_ext) + // ref: https://github.com/ggerganov/llama.cpp/pull/5021 + if (cparams.n_batch < LM_GGML_KQ_MASK_PAD) { + LLAMA_LOG_WARN("%s: n_batch is less than LM_GGML_KQ_MASK_PAD - increasing to %d\n", __func__, LM_GGML_KQ_MASK_PAD); + cparams.n_batch = LM_GGML_KQ_MASK_PAD; + } + + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : @@ -15280,6 +15474,23 @@ struct llama_context * llama_new_context_with_model( } } + if (cparams.flash_attn && hparams.use_alibi) { + LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__); + cparams.flash_attn = false; + } + + if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); + cparams.flash_attn = false; + } + +#ifdef LM_GGML_USE_HIPBLAS + if (cparams.flash_attn) { + LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__); + cparams.flash_attn = false; + } +#endif + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } @@ -15287,6 +15498,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); + LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -15415,7 +15627,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -16014,6 +16226,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) { const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_used = sizeof(uint32_t); + const size_t s_v_trans = sizeof(uint32_t); const size_t s_kv = ctx->kv_self.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; @@ -16031,10 +16244,14 @@ size_t llama_state_get_size(const struct llama_context * ctx) { + s_kv_head + s_kv_size + s_kv_used + + s_v_trans + s_kv + s_kv_cells ); + // on session change it is very likely that the state size has changed - so we need to update this function + static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. 
But did you remember to update llama_state_get_size?"); + return s_total; } @@ -16092,6 +16309,8 @@ struct llama_data_file_context : llama_data_context { * */ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { + llama_synchronize(ctx); + // copy rng { std::ostringstream rng_ss; @@ -16178,11 +16397,13 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data const uint32_t kv_size = kv_self.size; const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head; const uint32_t kv_used = kv_self.used; + const uint32_t v_trans = kv_self.v_trans ? 1 : 0; data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); data_ctx->write(&kv_head, sizeof(kv_head)); data_ctx->write(&kv_size, sizeof(kv_size)); data_ctx->write(&kv_used, sizeof(kv_used)); + data_ctx->write(&v_trans, sizeof(v_trans)); if (kv_buf_size) { const size_t pre_kv_buf_size = data_ctx->get_size_written(); @@ -16195,7 +16416,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data lm_ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = lm_ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); @@ -16244,6 +16465,8 @@ size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) { // Sets the state reading from the specified source address size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { + llama_synchronize(ctx); + const uint8_t * inp = src; // set rng @@ -16326,11 +16549,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { uint32_t kv_head; uint32_t kv_size; uint32_t kv_used; + uint32_t v_trans; memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head); memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used); + memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans); + + LM_GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition if (kv_self.size != kv_size) { // the KV cache needs to be big enough to load all the KV cells from the saved state @@ -16340,6 +16567,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { __func__, kv_head, kv_size, kv_self.size); } + llama_kv_cache_clear(ctx); + if (kv_buf_size) { const size_t pre_kv_buf_size = inp - src; @@ -16351,7 +16580,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { lm_ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size); inp += k_size; - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = lm_ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); @@ -16373,8 +16602,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { LM_GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size); } - llama_kv_cache_clear(ctx); - ctx->kv_self.head = kv_head; ctx->kv_self.used = kv_used; @@ -16548,6 +16775,8 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) } static size_t 
llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) { + llama_synchronize(ctx); + const auto & kv_self = ctx->kv_self; LM_GGML_ASSERT(!kv_self.recurrent); // not implemented @@ -16632,28 +16861,49 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam } } - // For the values, they are transposed, so we also need the element size and get the element ranges from each row - const uint32_t kv_size = kv_self.size; - for (int il = 0; il < (int)n_layer; ++il) { - // Write value type - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - data_ctx.write(&v_type_i, sizeof(v_type_i)); + // TODO: simplify, reduce copy-paste + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); - // Write element size - const size_t v_size_el = lm_ggml_type_size(kv_self.v_l[il]->type); - data_ctx.write(&v_size_el, sizeof(v_size_el)); + // Write row size of value + const size_t v_size_row = lm_ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + data_ctx.write(&v_size_row, sizeof(v_size_row)); - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - tmp_buf.resize(range_size * v_size_el); - lm_ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + tmp_buf.resize(range_size * v_size_row); + lm_ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row); data_ctx.write(tmp_buf.data(), tmp_buf.size()); } } + } else { + // For the values, they are transposed, so we also need the element size and get the element ranges from each row + const uint32_t kv_size = kv_self.size; + for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const size_t v_size_el = lm_ggml_type_size(kv_self.v_l[il]->type); + data_ctx.write(&v_size_el, sizeof(v_size_el)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + tmp_buf.resize(range_size * v_size_el); + lm_ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + data_ctx.write(tmp_buf.data(), tmp_buf.size()); + } + } + } } return data_ctx.get_size_written(); @@ -16665,6 +16915,8 @@ size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_s } size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { + llama_synchronize(ctx); + auto & kv_self = ctx->kv_self; LM_GGML_ASSERT(!kv_self.recurrent); // not implemented @@ -16776,41 +17028,75 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } } - // For each layer, read the values for 
each cell (transposed) - for (int il = 0; il < (int)n_layer; ++il) { - // Read type of value - int32_t v_type_i_ref; - memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); - inp += sizeof(v_type_i_ref); - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - if (v_type_i != v_type_i_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return 0; - } + // TODO: simplify, reduce copy-paste + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } - // Read element size of value - size_t v_size_el_ref; - memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); - inp += sizeof(v_size_el_ref); - const size_t v_size_el = lm_ggml_type_size(kv_self.v_l[il]->type); - if (v_size_el != v_size_el_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); - return 0; - } + // Read row size of value + size_t v_size_row_ref; + memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref)); + inp += sizeof(v_size_row_ref); + const size_t v_size_row = lm_ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); + return 0; + } - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; - lm_ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); - inp += cell_count * v_size_el; + if (cell_count) { + // Read and set the values for the whole cell range + lm_ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row); + inp += cell_count * v_size_row; + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } + + // Read element size of value + size_t v_size_el_ref; + memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); + inp += sizeof(v_size_el_ref); + const size_t v_size_el = lm_ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); + return 0; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { 
+ const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; + lm_ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); + inp += cell_count * v_size_el; + } } } } const size_t nread = inp - src; + return nread; } @@ -17606,6 +17892,11 @@ const char * llama_print_system_info(void) { s += "SSSE3 = " + std::to_string(lm_ggml_cpu_has_ssse3()) + " | "; s += "VSX = " + std::to_string(lm_ggml_cpu_has_vsx()) + " | "; s += "MATMUL_INT8 = " + std::to_string(lm_ggml_cpu_has_matmul_int8()) + " | "; +#ifdef LM_GGML_USE_LLAMAFILE + s += "LLAMAFILE = 1 | "; +#else + s += "LLAMAFILE = 0 | "; +#endif return s.c_str(); } diff --git a/cpp/llama.h b/cpp/llama.h index b92740f..faab40c 100644 --- a/cpp/llama.h +++ b/cpp/llama.h @@ -40,7 +40,7 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 5 +#define LLAMA_SESSION_VERSION 6 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 1 @@ -69,6 +69,18 @@ extern "C" { LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece }; + // pre-tokenization types + enum llama_vocab_pre_type { + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + }; + // note: these values should be synchronized with lm_ggml_rope // TODO: maybe move this enum to ggml.h (lm_ggml_rope_type) enum llama_rope_type { @@ -159,7 +171,7 @@ extern "C" { bool sorted; } llama_token_data_array; - typedef bool (*llama_progress_callback)(float progress, void *ctx); + typedef bool (*llama_progress_callback)(float progress, void * user_data); // Input data for llama_decode // A llama_batch object can contain input about one or many sequences @@ -195,15 +207,19 @@ extern "C" { LLAMA_KV_OVERRIDE_TYPE_INT, LLAMA_KV_OVERRIDE_TYPE_FLOAT, LLAMA_KV_OVERRIDE_TYPE_BOOL, + LLAMA_KV_OVERRIDE_TYPE_STR, }; struct llama_model_kv_override { - char key[128]; enum llama_model_kv_override_type tag; + + char key[128]; + union { - int64_t int_value; - double float_value; - bool bool_value; + int64_t val_i64; + double val_f64; + bool val_bool; + char val_str[128]; }; }; @@ -232,9 +248,10 @@ extern "C" { const struct llama_model_kv_override * kv_overrides; // Keep the booleans together to avoid misalignment during copy-by-value. 
-        bool vocab_only;  // only load the vocabulary, no weights
-        bool use_mmap;    // use mmap if possible
-        bool use_mlock;   // force system to keep model in RAM
+        bool vocab_only;    // only load the vocabulary, no weights
+        bool use_mmap;      // use mmap if possible
+        bool use_mlock;     // force system to keep model in RAM
+        bool check_tensors; // validate model tensor data
     };

     struct llama_context_params {
@@ -270,6 +287,7 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // whether to use flash attention

         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -288,6 +306,7 @@ extern "C" {
         bool quantize_output_tensor; // quantize output.weight
         bool only_copy;              // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
         bool pure;                   // quantize all tensors to the default type
+        bool keep_split;             // quantize to the same number of shards
         void * imatrix;              // pointer to importance matrix data
         void * kv_overrides;         // pointer to vector containing overrides
     } llama_model_quantize_params;
@@ -524,7 +543,7 @@ extern "C" {
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
     LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);

-    // Clear the KV cache
+    // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_cache_clear(
             struct llama_context * ctx);

diff --git a/cpp/log.h b/cpp/log.h
index 5d6df5f..d5d4517 100644
--- a/cpp/log.h
+++ b/cpp/log.h
@@ -234,7 +234,7 @@ inline std::string log_filename_generator_impl(LogTriState multilog, const std::
 // INTERNAL, DO NOT USE
 //  USE LOG() INSTEAD
 //
-#if !defined(_MSC_VER) or defined(__INTEL_LLVM_COMPILER)
+#if !defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER) || defined(__clang__)
 #define LOG_IMPL(str, ...) \
     do { \
         if (LOG_TARGET != nullptr) \
@@ -257,7 +257,7 @@ inline std::string log_filename_generator_impl(LogTriState multilog, const std::
 // INTERNAL, DO NOT USE
 //  USE LOG_TEE() INSTEAD
 //
-#if !defined(_MSC_VER) or defined(__INTEL_LLVM_COMPILER)
+#if !defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER) || defined(__clang__)
 #define LOG_TEE_IMPL(str, ...) 
\ do { \ if (LOG_TARGET != nullptr) \ diff --git a/cpp/sampling.cpp b/cpp/sampling.cpp index d0bd85e..a9b6b32 100644 --- a/cpp/sampling.cpp +++ b/cpp/sampling.cpp @@ -68,7 +68,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) { void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); + seed = std::random_device{}(); } ctx->rng.seed(seed); } diff --git a/cpp/sgemm.cpp b/cpp/sgemm.cpp index 41c60ba..1378e94 100644 --- a/cpp/sgemm.cpp +++ b/cpp/sgemm.cpp @@ -50,7 +50,6 @@ #pragma GCC diagnostic ignored "-Wignored-attributes" #include "sgemm.h" -#include #include "ggml-impl.h" #include "ggml-quants.h" @@ -243,23 +242,23 @@ template <> inline __m512 load(const lm_ggml_fp16_t *p) { template class tinyBLAS { public: - tinyBLAS(int k, - const TA *A, int lda, - const TB *B, int ldb, - TC *C, int ldc, + tinyBLAS(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int m, int n, int task) { + void matmul(int64_t m, int64_t n, int task) { if (task == LM_GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - NOINLINE void mnpack(int m0, int m, int n0, int n) { - int mc, nc, mp, np; - switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) { + NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { #if VECTOR_REGISTERS == 32 case 0x55: mc = 5; @@ -409,27 +408,27 @@ class tinyBLAS { } template - NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; - int tiles = xtiles * ytiles; - int duty = (tiles + nth - 1) / nth; - int start = duty * ith; - int end = start + duty; + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; if (end > tiles) end = tiles; - for (int job = start; job < end; ++job) { - int ii = m0 + job / xtiles * RM; - int jj = n0 + job % xtiles * RN; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; D Cv[RN][RM] = {}; - for (int l = 0; l < k; l += KN) - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t l = 0; l < k; l += KN) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) Cv[j][i] = madd(load(A + lda * (ii + i) + l), load(B + ldb * (jj + j) + l), Cv[j][i]); - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } @@ -437,10 +436,10 @@ class tinyBLAS { const TA *const A; const TB *const B; TC *const C; - const int k; - const int lda; - const int ldb; - const int ldc; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; const int ith; const int nth; }; @@ -452,23 +451,23 @@ class tinyBLAS { template class tinyBLAS_Q0_ARM { public: - tinyBLAS_Q0_ARM(int k, - const TA *A, int lda, - const block_q8_0 *B, int ldb, - float *C, int ldc, + tinyBLAS_Q0_ARM(int64_t k, + const TA *A, int64_t lda, + const block_q8_0 *B, int64_t ldb, + float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), 
ith(ith), nth(nth) { } - void matmul(int m, int n, int task) { + void matmul(int64_t m, int64_t n, int task) { if (task == LM_GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - NOINLINE void mnpack(int m0, int m, int n0, int n) { - int mc, nc, mp, np; - switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) { + NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) { case 0x33: mc = 3; nc = 3; @@ -524,22 +523,22 @@ class tinyBLAS_Q0_ARM { } template - NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; - int tiles = xtiles * ytiles; - int duty = (tiles + nth - 1) / nth; - int start = duty * ith; - int end = start + duty; + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; if (end > tiles) end = tiles; - for (int job = start; job < end; ++job) { - int ii = m0 + job / xtiles * RM; - int jj = n0 + job % xtiles * RN; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; float32x4_t Cv[RN][RM] = {}; - for (int l = 0; l < k; ++l) - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) Cv[j][i] = vmlaq_n_f32(Cv[j][i], vcvtq_f32_s32(vdotq_s32( vdotq_s32(vdupq_n_s32(0), @@ -549,8 +548,8 @@ class tinyBLAS_Q0_ARM { load_hi(B + ldb * (jj + j) + l))), unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)); - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } @@ -577,10 +576,10 @@ class tinyBLAS_Q0_ARM { const TA *const A; const block_q8_0 *const B; float *const C; - const int k; - const int lda; - const int ldb; - const int ldc; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; const int ith; const int nth; }; @@ -590,23 +589,23 @@ class tinyBLAS_Q0_ARM { template class tinyBLAS_Q0_AVX2 { public: - tinyBLAS_Q0_AVX2(int k, - const TA *A, int lda, - const TB *B, int ldb, - TC *C, int ldc, + tinyBLAS_Q0_AVX2(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int m, int n, int task) { + void matmul(int64_t m, int64_t n, int task) { if (task == LM_GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - void mnpack(int m0, int m, int n0, int n) { - int mc, nc, mp, np; - switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) { + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { #if VECTOR_REGISTERS == 32 case 0x44: mc = 4; @@ -714,22 +713,22 @@ class tinyBLAS_Q0_AVX2 { } template - NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; - int tiles = xtiles * ytiles; - int duty = (tiles + nth - 1) / nth; - int start = duty * ith; - int end = start + duty; + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + 
int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; if (end > tiles) end = tiles; - for (int job = start; job < end; ++job) { - int ii = m0 + job / xtiles * RM; - int jj = n0 + job % xtiles * RN; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; __m256 Cv[RN][RM] = {}; - for (int l = 0; l < k; ++l) - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)), updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), @@ -737,8 +736,8 @@ class tinyBLAS_Q0_AVX2 { _mm256_sign_epi8(load(B + ldb * (jj + j) + l), load(A + lda * (ii + i) + l))), Cv[j][i]); - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } @@ -771,10 +770,10 @@ class tinyBLAS_Q0_AVX2 { const TA *const A; const TB *const B; TC *const C; - const int k; - const int lda; - const int ldb; - const int ldc; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; const int ith; const int nth; }; @@ -813,8 +812,8 @@ class tinyBLAS_Q0_AVX2 { * @param Ctype is GGML data type of `C` * @return true if this function was able to service the matmul request */ -bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C, - int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { +bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, + int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { assert(m >= 0); assert(n >= 0); @@ -824,9 +823,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, assert(ldc >= m); assert(nth > 0); assert(ith < nth); - assert(1ll * lda * m <= 0x7fffffff); - assert(1ll * ldb * n <= 0x7fffffff); - assert(1ll * ldc * n <= 0x7fffffff); if (Ctype != LM_GGML_TYPE_F32) return false; diff --git a/cpp/sgemm.h b/cpp/sgemm.h index da23b20..f29747d 100644 --- a/cpp/sgemm.h +++ b/cpp/sgemm.h @@ -1,11 +1,13 @@ #pragma once +#include #include #ifdef __cplusplus extern "C" { #endif -bool llamafile_sgemm(int, int, int, const void *, int, const void *, int, - void *, int, int, int, int, int, int, int); +bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t, + const void *, int64_t, void *, int64_t, int, int, + int, int, int, int); #ifdef __cplusplus } diff --git a/cpp/unicode-data.cpp b/cpp/unicode-data.cpp index 22f8b0f..e6bafb3 100644 --- a/cpp/unicode-data.cpp +++ b/cpp/unicode-data.cpp @@ -1,4 +1,4 @@ -#include "unicode-data.h" +#include "unicode-data.h" #include #include diff --git a/cpp/unicode-data.h b/cpp/unicode-data.h index b99500b..cb9dd8a 100644 --- a/cpp/unicode-data.h +++ b/cpp/unicode-data.h @@ -12,5 +12,5 @@ extern const std::vector> unicode_ranges_accent_ma extern const std::vector> unicode_ranges_punctuation; extern const std::vector> unicode_ranges_symbol; extern const std::vector> unicode_ranges_control; -extern const std::multimap unicode_map_nfd; -extern const std::map unicode_map_lowercase; +extern const std::multimap unicode_map_nfd; +extern const std::map unicode_map_lowercase; diff --git a/cpp/unicode.cpp b/cpp/unicode.cpp 
index df8c5f5..f2ccda0 100644 --- a/cpp/unicode.cpp +++ b/cpp/unicode.cpp @@ -5,11 +5,14 @@ #include #include #include +#include #include #include #include #include #include +#include +#include static std::string unicode_cpts_to_utf8(const std::vector & cps) { std::string result; @@ -53,23 +56,22 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) offset += 4; return result; } - throw std::invalid_argument("invalid string"); + throw std::invalid_argument("failed to convert utf8 to codepoint"); } -static std::vector unicode_cpt_to_utf16(uint32_t cp) { - std::vector result; - if (/* 0x0000 <= cp && */ cp <= 0xffff) { - result.emplace_back(cp); - } - else if (0x10000 <= cp && cp <= 0x10ffff) { - result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); - result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); - } - else { - throw std::invalid_argument("invalid cpt"); - } - return result; -} +//static std::vector unicode_cpt_to_utf16(uint32_t cp) { +// std::vector result; +// if (/* 0x0000 <= cp && */ cp <= 0xffff) { +// result.emplace_back(cp); +// return result; +// } +// if (0x10000 <= cp && cp <= 0x10ffff) { +// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); +// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); +// return result; +// } +// throw std::invalid_argument("failed to convert codepoint to utf16"); +//} //static std::vector unicode_cpts_to_utf16(const std::vector & cps) { // std::vector result; @@ -80,28 +82,28 @@ static std::vector unicode_cpt_to_utf16(uint32_t cp) { // return result; //} -static uint32_t cpt_from_utf16(const std::vector & utf16, size_t & offset) { - assert(offset < utf16.size()); - if (((utf16[0] >> 10) << 10) != 0xd800) { - auto result = utf16[offset + 0]; - offset += 1; - return result; - } - - if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) { - throw std::invalid_argument("invalid character"); - } - - auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); - offset += 2; - return result; -} +//static uint32_t unicode_cpt_from_utf16(const std::vector & utf16, size_t & offset) { +// assert(offset < utf16.size()); +// if (((utf16[0] >> 10) << 10) != 0xd800) { +// auto result = utf16[offset + 0]; +// offset += 1; +// return result; +// } +// +// if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) { +// throw std::invalid_argument("invalid character"); +// } +// +// auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); +// offset += 2; +// return result; +//} //static std::vector unicode_cpts_from_utf16(const std::vector & utf16) { // std::vector result; // size_t offset = 0; // while (offset < utf16.size()) { -// result.push_back(cpt_from_utf16(utf16, offset)); +// result.push_back(unicode_cpt_from_utf16(utf16, offset)); // } // return result; //} @@ -194,34 +196,277 @@ static std::unordered_map unicode_utf8_to_byte_map() { return map; } +static inline std::wstring unicode_wstring_from_utf8(const std::string & s) { + std::wstring_convert> conv; + return conv.from_bytes(s); +} + +static std::vector unicode_byte_encoding_process(const std::vector & bpe_words) { + std::vector bpe_encoded_words; + for (const auto & word : bpe_words) { + std::string text_utf; + auto utf_word = unicode_cpts_from_utf8(word); + for (size_t i = 0; i < utf_word.size(); ++i) { + text_utf += unicode_cpt_to_utf8(utf_word[i]); + } + + std::string encoded_token; + for (char & c : text_utf) { + encoded_token += unicode_byte_to_utf8(c); + } + 
bpe_encoded_words.emplace_back(encoded_token); + } + return bpe_encoded_words; +} + +// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ +static std::vector unicode_regex_split_custom_gpt2(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + size_t start = 0; + + const auto cpts = unicode_cpts_from_utf8(text); + + for (auto offset : offsets) { + std::string token; + + bool collecting_numeric = false; + bool collecting_letter = false; + bool collecting_special = false; + bool collecting_whitespace_lookahead = false; + bool collecting = false; + + std::vector text_utf; + text_utf.reserve(offset); + + for (size_t i = start; i < start + offset; ++i) { + text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i])); + } + + for (int i = 0; i < (int)text_utf.size(); i++) { + const std::string & utf_char = text_utf[i]; + bool split_condition = false; + int bytes_remain = text_utf.size() - i; + + // forward backward lookups + const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; + const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : ""; + + // handling contractions + if (!split_condition && bytes_remain >= 2) { + // 's|'t|'m|'d + if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) { + split_condition = true; + } + if (split_condition) { + if (token.size()) { + bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + } + token = utf_char + utf_char_next; + bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + token = ""; + i++; + continue; + } + } + if (!split_condition && bytes_remain >= 3) { + // 're|'ve|'ll + if (utf_char == "\'" && ( + (utf_char_next == "r" && utf_char_next_next == "e") || + (utf_char_next == "v" && utf_char_next_next == "e") || + (utf_char_next == "l" && utf_char_next_next == "l")) + ) { + split_condition = true; + } + if (split_condition) { + // current token + next token can be defined + if (token.size()) { + bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + } + token = utf_char; + token += utf_char_next; + token += utf_char_next_next; + + bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + token = ""; + i += 2; + continue; + } + } + + if (!split_condition && !collecting) { + if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { + collecting_letter = true; + collecting = true; + } + else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { + collecting_numeric = true; + collecting = true; + } + else if ( + ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) || + (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) + ) { + collecting_special = true; + collecting = true; + } + else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) { + 
collecting_whitespace_lookahead = true; + collecting = true; + } + else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) { + split_condition = true; + } + } + else if (!split_condition && collecting) { + if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) { + split_condition = true; + } + else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) { + split_condition = true; + } + else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { + split_condition = true; + } + else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { + split_condition = true; + } + } + + if (utf_char_next == "") { + split_condition = true; // final + token += utf_char; + } + + if (split_condition) { + if (token.size()) { + bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + } + token = utf_char; + collecting = false; + collecting_letter = false; + collecting_numeric = false; + collecting_special = false; + collecting_whitespace_lookahead = false; + } + else { + token += utf_char; + } + } + + start += offset; + } + + return bpe_offsets; +} + +// use std::wregex to split the text +static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) { + std::wregex expr(regex_expr); + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + size_t start = 0; + for (auto offset : offsets) { + std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr); + std::wcregex_iterator end; + + int64_t start_idx = 0; + while (it != end) { + std::wcmatch match = *it; + if (match.position() > start_idx) { + bpe_offsets.emplace_back(match.position() - start_idx); + } + bpe_offsets.emplace_back(match.length()); + start_idx = match.position() + match.length(); + ++it; + } + + if (start_idx < (int64_t) offset) { + bpe_offsets.emplace_back(offset - start_idx); + } + start += offset; + } + + return bpe_offsets; +} + +// use std::regex to split the text +static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { + std::regex expr(regex_expr); + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + size_t start = 0; + for (auto offset : offsets) { + std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr); + std::cregex_iterator end; + + int64_t start_idx = 0; + while (it != end) { + std::cmatch match = *it; + if (match.position() > start_idx) { + bpe_offsets.emplace_back(match.position() - start_idx); + } + bpe_offsets.emplace_back(match.length()); + start_idx = match.position() + match.length(); + ++it; + } + + if (start_idx < (int64_t) offset) { + bpe_offsets.emplace_back(offset - start_idx); + } + start += offset; + } + + return bpe_offsets; +} + +static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { + std::vector bpe_offsets; + + (void)(text); + (void)(regex_expr); + (void)(offsets); + // TODO: this implementation is actually wrong, uncomment and run: + // make -j && 
./bin/test-tokenizer-0 ../models/ggml-vocab-gpt-2.gguf + //if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") { + // bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets); + //} + + return bpe_offsets; +} + // // interface // std::string unicode_cpt_to_utf8(uint32_t cp) { std::string result; + if (/* 0x00 <= cp && */ cp <= 0x7f) { result.push_back(cp); + return result; } - else if (0x80 <= cp && cp <= 0x7ff) { + if (0x80 <= cp && cp <= 0x7ff) { result.push_back(0xc0 | ((cp >> 6) & 0x1f)); result.push_back(0x80 | (cp & 0x3f)); + return result; } - else if (0x800 <= cp && cp <= 0xffff) { + if (0x800 <= cp && cp <= 0xffff) { result.push_back(0xe0 | ((cp >> 12) & 0x0f)); result.push_back(0x80 | ((cp >> 6) & 0x3f)); result.push_back(0x80 | (cp & 0x3f)); + return result; } - else if (0x10000 <= cp && cp <= 0x10ffff) { + if (0x10000 <= cp && cp <= 0x10ffff) { result.push_back(0xf0 | ((cp >> 18) & 0x07)); result.push_back(0x80 | ((cp >> 12) & 0x3f)); result.push_back(0x80 | ((cp >> 6) & 0x3f)); result.push_back(0x80 | (cp & 0x3f)); + return result; } - else { - throw std::invalid_argument("invalid codepoint"); - } - return result; + + throw std::invalid_argument("invalid codepoint"); } std::vector unicode_cpts_normalize_nfd(const std::vector & cpts) { @@ -275,3 +520,167 @@ char32_t unicode_tolower(char32_t cp) { auto it = unicode_map_lowercase.find(cp); return it == unicode_map_lowercase.end() ? cp : it->second; } + +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { + // unicode categories + static const std::map k_ucat_enum = { + { "\\p{N}", CODEPOINT_TYPE_DIGIT }, + { "\\p{L}", CODEPOINT_TYPE_LETTER }, + { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION }, + }; + + static const std::map k_ucat_cpt = { + { CODEPOINT_TYPE_DIGIT, 0xD1 }, + { CODEPOINT_TYPE_LETTER, 0xD2 }, + { CODEPOINT_TYPE_PUNCTUATION, 0xD3 }, + }; + + static const std::map k_ucat_map = { + { CODEPOINT_TYPE_DIGIT, "\x30-\x39" }, // 0-9 + { CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z + { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} + }; + + // compute collapsed codepoints only if needed by at least one regex + bool need_collapse = false; + for (auto & regex_expr : regex_exprs) { + // search for unicode categories + for (const auto & ucat : k_ucat_enum) { + if (std::string::npos != regex_expr.find(ucat.first)) { + need_collapse = true; + break; + } + } + } + + const auto cpts = unicode_cpts_from_utf8(text); + + // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte + // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 + std::string text_collapsed; + if (need_collapse) { + // collapse all unicode categories + text_collapsed.resize(cpts.size()); + + for (size_t i = 0; i < cpts.size(); ++i) { + // keep single-byte codepoints as is + if (cpts[i] < 128) { + text_collapsed[i] = cpts[i]; + continue; + } + + const int cpt_type = unicode_cpt_type(cpts[i]); + + if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) { + text_collapsed[i] = k_ucat_cpt.at(cpt_type); + } else { + text_collapsed[i] = (char) 0xD0; // fallback + } + } + } + + std::vector bpe_offsets = { cpts.size() }; + + for (auto & regex_expr : regex_exprs) { + // first, see if we have an efficient custom regex implementation + auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets); + + if (!tmp.empty()) 
{ + bpe_offsets = std::move(tmp); + continue; + } + + // fallback to general-purpose std::regex / std::wregex + try { + // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category + // with the corresponding collapsed representation + bool use_collapsed = false; + for (auto & ucat : k_ucat_enum) { + if (std::string::npos != regex_expr.find(ucat.first)) { + use_collapsed = true; + break; + } + } + + if (use_collapsed) { + // sanity-check that the original regex does not contain any non-ASCII characters + const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); + for (size_t i = 0; i < cpts_regex.size(); ++i) { + if (cpts_regex[i] >= 128) { + throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported"); + } + } + + // generate a collapsed representation of the regex + std::string regex_expr_collapsed; + + // track if we are inside [], because nested [] are not allowed + bool inside = false; + for (size_t i = 0; i < regex_expr.size(); ++i) { + if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) { + regex_expr_collapsed += '['; + inside = true; + continue; + } + + if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') { + regex_expr_collapsed += ']'; + inside = false; + continue; + } + + if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && + regex_expr[i + 1] == 'p' && + regex_expr[i + 2] == '{' && + regex_expr[i + 4] == '}') { + const std::string pat = regex_expr.substr(i, 5); + if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { + if (!inside) { + regex_expr_collapsed += '['; + } + regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); + regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); + if (!inside) { + regex_expr_collapsed += ']'; + } + i += 4; + continue; + } + } + + regex_expr_collapsed += regex_expr[i]; + } + + //printf("text_collapsed: %s\n", text_collapsed.c_str()); + //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str()); + bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); + } else { + // no unicode category used, we can use std::wregex directly + const std::wstring wtext = unicode_wstring_from_utf8(text); + const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); + + //printf("text: %s\n", text.c_str()); + //printf("regex_expr: %s\n", regex_expr.c_str()); + bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets); + } + } catch (std::regex_error & e) { + fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str()); + fprintf(stderr, "Regex error: %s\n", e.what()); + throw std::runtime_error("Failed to process regex"); + } + } + + std::vector bpe_words; + bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size + + size_t start = 0; + for (size_t & offset : bpe_offsets) { + bpe_words.emplace_back(); + for (size_t i = start; i < start + offset; ++i) { + bpe_words.back() += unicode_cpt_to_utf8(cpts[i]); + } + start += offset; + } + + return unicode_byte_encoding_process(bpe_words); +} diff --git a/cpp/unicode.h b/cpp/unicode.h index 6a0be39..ce2bcef 100644 --- a/cpp/unicode.h +++ b/cpp/unicode.h @@ -24,5 +24,6 @@ int unicode_cpt_type(const std::string & utf8); std::string unicode_byte_to_utf8(uint8_t byte); uint8_t unicode_utf8_to_byte(const std::string & utf8); -// simple tolower that only implements one-to-one mapping, not one-to-many char32_t unicode_tolower(char32_t cp); + +std::vector unicode_regex_split(const std::string & 
text, const std::vector & regex_exprs); diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock index f56101c..98dfd23 100644 --- a/example/ios/Podfile.lock +++ b/example/ios/Podfile.lock @@ -8,7 +8,7 @@ PODS: - hermes-engine/Pre-built (= 0.72.3) - hermes-engine/Pre-built (0.72.3) - libevent (2.1.12) - - llama-rn (0.3.0-rc.17): + - llama-rn (0.3.0): - RCT-Folly - RCTRequired - RCTTypeSafety @@ -1261,7 +1261,7 @@ SPEC CHECKSUMS: glog: 04b94705f318337d7ead9e6d17c019bd9b1f6b1b hermes-engine: 10fbd3f62405c41ea07e71973ea61e1878d07322 libevent: 4049cae6c81cdb3654a443be001fb9bdceff7913 - llama-rn: 84baa8193bdd10582bfc9eee3b963bbcc9153e5b + llama-rn: 109c3ca266fc0d7f4d086ceccc1f955a19669270 RCT-Folly: 424b8c9a7a0b9ab2886ffe9c3b041ef628fd4fb1 RCTRequired: a2faf4bad4e438ca37b2040cb8f7799baa065c18 RCTTypeSafety: cb09f3e4747b6d18331a15eb05271de7441ca0b3 diff --git a/llama.cpp b/llama.cpp index 784e11d..a2ac89d 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit 784e11dea1f5ce9638851b2b0dddb107e2a609c8 +Subproject commit a2ac89d6efb41b535778bfeaecaae8fe295b6ed3 diff --git a/scripts/common.h.patch b/scripts/common.h.patch index 1b66805..966a1fd 100644 --- a/scripts/common.h.patch +++ b/scripts/common.h.patch @@ -1,6 +1,6 @@ ---- common.h.orig 2024-04-25 11:10:50 -+++ common.h 2024-04-25 11:10:51 -@@ -42,6 +42,17 @@ +--- common.h.orig 2024-05-04 13:24:18 ++++ common.h 2024-05-04 13:24:19 +@@ -44,6 +44,17 @@ int get_math_cpu_count(); int32_t get_num_physical_cores(); diff --git a/scripts/ggml-metal.m.patch b/scripts/ggml-metal.m.patch index 76f577d..c123fc3 100644 --- a/scripts/ggml-metal.m.patch +++ b/scripts/ggml-metal.m.patch @@ -1,6 +1,6 @@ ---- ggml-metal.m.orig 2024-04-25 11:10:50 -+++ ggml-metal.m 2024-04-25 11:10:51 -@@ -311,7 +311,7 @@ +--- ggml-metal.m.orig 2024-05-04 13:24:18 ++++ ggml-metal.m 2024-05-04 13:24:19 +@@ -321,7 +321,7 @@ const bool try_metallib = true; #endif diff --git a/scripts/llama.cpp.patch b/scripts/llama.cpp.patch index 1071558..ac51a0f 100644 --- a/scripts/llama.cpp.patch +++ b/scripts/llama.cpp.patch @@ -1,5 +1,5 @@ ---- llama.cpp.orig 2024-04-25 11:10:50 -+++ llama.cpp 2024-04-25 11:10:51 +--- llama.cpp.orig 2024-05-04 13:24:18 ++++ llama.cpp 2024-05-04 13:24:19 @@ -120,6 +120,17 @@ #define LLAMA_LOG_WARN(...) llama_log_internal(LM_GGML_LOG_LEVEL_WARN , __VA_ARGS__) #define LLAMA_LOG_ERROR(...) llama_log_internal(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__) @@ -18,7 +18,7 @@ // // helpers // -@@ -1301,16 +1312,16 @@ +@@ -1303,16 +1314,16 @@ if (prefetch > 0) { // advise the kernel to preload the mapped memory
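The new check_tensors and flash_attn booleans are plain fields on llama_model_params and llama_context_params, so callers opt in by flipping them on top of the defaults returned above. A minimal sketch, assuming the usual llama_load_model_from_file() loader and leaving ownership/cleanup to the caller:

// Sketch only: opts into the two new flags added by this change.
#include "llama.h"

static llama_context * make_ctx(const char * model_path) {
    llama_model_params mparams = llama_model_default_params();
    mparams.check_tensors = true;                 // new: validate tensor data while loading

    llama_model * model = llama_load_model_from_file(model_path, mparams);
    if (model == nullptr) {
        return nullptr;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = true;                    // new: request flash attention
    // note: the context may still force this off (ALiBi models, Grok, HIPBLAS builds)

    return llama_new_context_with_model(model, cparams);
}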
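Because the quantizer now casts kv_overrides back to a std::vector of llama_model_kv_override and stops at the first entry whose key starts with '\0', a string override (the new LLAMA_KV_OVERRIDE_TYPE_STR tag) plus keep_split can be wired up roughly as below. This is a sketch under assumptions: "general.name" is only an example key, and llama_model_quantize() is assumed to keep its usual (input path, output path, params) signature.

// Sketch: quantize with a string KV override and shard-preserving output.
#include <cstring>
#include <vector>
#include "llama.h"

int main() {
    llama_model_kv_override kvo{};
    kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
    std::strncpy(kvo.key,     "general.name", sizeof(kvo.key)     - 1); // example key (assumption)
    std::strncpy(kvo.val_str, "my-model",     sizeof(kvo.val_str) - 1); // values are capped at 127 chars

    // zero-initialized terminator: the quantizer breaks on o.key[0] == 0
    std::vector<llama_model_kv_override> overrides { kvo, llama_model_kv_override{} };

    llama_model_quantize_params qparams = llama_model_quantize_default_params();
    qparams.kv_overrides = &overrides;   // read back as a std::vector<llama_model_kv_override>*
    qparams.keep_split   = true;         // new: emit the same number of shards as the input

    return llama_model_quantize("model-f16.gguf", "model-q4.gguf", &qparams) == 0 ? 0 : 1;
}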
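Both state serialization paths now call llama_synchronize() up front and write a v_trans flag next to the KV-cache header, which is what bumps LLAMA_SESSION_VERSION to 6: a buffer captured from a non-transposed V cache (the flash-attention layout) asserts if restored into a transposed one. A small round-trip sketch using the public state calls shown above:

// Sketch: snapshot and restore full context state with the version-6 layout.
#include <cstdint>
#include <vector>
#include "llama.h"

static std::vector<uint8_t> snapshot_state(llama_context * ctx) {
    std::vector<uint8_t> buf(llama_state_get_size(ctx)); // upper bound for the serialized state
    const size_t written = llama_state_get_data(ctx, buf.data());
    buf.resize(written);
    return buf;
}

static void restore_state(llama_context * ctx, const std::vector<uint8_t> & buf) {
    // asserts if the saved v_trans flag does not match this context's KV layout
    llama_state_set_data(ctx, buf.data());
}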
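unicode_regex_split() is an internal helper rather than part of the llama.h surface, but the collapsed-category fallback above means a GPT-2 style pattern only needs std::regex / std::wregex underneath. A sketch, assuming the caller links unicode.cpp directly and accepting that the splits come back byte-encoded for BPE (unicode_byte_encoding_process is applied to the result):

// Sketch: split text with a GPT-2 style pre-tokenizer pattern via the new helper.
#include <string>
#include <vector>
#include "unicode.h"

static std::vector<std::string> split_like_gpt2(const std::string & text) {
    // \p{L} / \p{N} are collapsed to single-byte placeholders internally,
    // so the expression stays within what std::regex (ECMAScript) supports
    static const std::vector<std::string> regex_exprs = {
        "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+",
    };
    return unicode_regex_split(text, regex_exprs);
}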