feat: sync llama.cpp (#79)
* feat: sync llama.cpp

* fix: fix submodule update - as part of llama.cpp sync

* chore: remove unnecessary comment

* chore(example): revert unnecessary changes

* feat: sync llama.cpp

* fix: remove tfs_z

ref: ggerganov/llama.cpp#10071

* fix(cpp): skip gpu device if n_gpu_layers <= 0

ref: ggerganov/llama.cpp#10132

---------

Co-authored-by: Jhen-Jie Hong <[email protected]>
a-ghorbani and jhen0409 authored Nov 2, 2024
1 parent f35545b commit 1ca3044
Showing 55 changed files with 21,474 additions and 13,484 deletions.
22 changes: 17 additions & 5 deletions android/src/main/CMakeLists.txt
@@ -9,21 +9,23 @@ include_directories(${RNLLAMA_LIB_DIR})

set(
SOURCE_FILES
${RNLLAMA_LIB_DIR}/llama-grammar.cpp
${RNLLAMA_LIB_DIR}/llama-sampling.cpp
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
${RNLLAMA_LIB_DIR}/log.cpp

${RNLLAMA_LIB_DIR}/ggml-aarch64.c
${RNLLAMA_LIB_DIR}/ggml-alloc.c
${RNLLAMA_LIB_DIR}/ggml-backend.c
${RNLLAMA_LIB_DIR}/ggml-backend.cpp
${RNLLAMA_LIB_DIR}/ggml.c
${RNLLAMA_LIB_DIR}/ggml-quants.c
${RNLLAMA_LIB_DIR}/common.cpp
${RNLLAMA_LIB_DIR}/grammar-parser.cpp
${RNLLAMA_LIB_DIR}/json.hpp
${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
${RNLLAMA_LIB_DIR}/sampling.cpp
${RNLLAMA_LIB_DIR}/unicode-data.cpp
${RNLLAMA_LIB_DIR}/unicode.cpp
${RNLLAMA_LIB_DIR}/llama.cpp
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
${RNLLAMA_LIB_DIR}/llama-sampling.cpp
${RNLLAMA_LIB_DIR}/llama-grammar.cpp
${RNLLAMA_LIB_DIR}/sgemm.cpp
${RNLLAMA_LIB_DIR}/ggml-aarch64.c
${RNLLAMA_LIB_DIR}/rn-llama.hpp
@@ -65,10 +67,20 @@ build_library("rnllama" "")

if (${ANDROID_ABI} STREQUAL "arm64-v8a")
# ARM64 targets
build_library("rnllama_v8_4_fp16_dotprod_sve" "-march=armv8.4-a+fp16+dotprod+sve")
build_library("rnllama_v8_4_fp16_dotprod_i8mm_sve" "-march=armv8.4-a+fp16+dotprod+i8mm+sve")
build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm")
build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod")
build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod")
build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16")
build_library("rnllama_v8" "-march=armv8-a")

# https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk
# llama.cpp will deal with the cpu features
# build_library("rnllama_v8_7" "-march=armv8.7-a")
# TODO: Add support runtime check for cpu features
# At the moment runtime check is failing.

elseif (${ANDROID_ABI} STREQUAL "x86_64")
# x86_64 target
build_library("rnllama_x86_64" "-march=x86-64" "-mtune=intel" "-msse4.2" "-mpopcnt")
60 changes: 46 additions & 14 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -35,6 +35,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
if (!params.hasKey("model")) {
throw new IllegalArgumentException("Missing required parameter: model");
}
Log.d(NAME, "Setting log callback");
logToAndroid();
this.id = id;
this.context = initContext(
// String model,
@@ -53,6 +55,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
// boolean use_mmap,
params.hasKey("use_mmap") ? params.getBoolean("use_mmap") : true,
//boolean vocab_only,
params.hasKey("vocab_only") ? params.getBoolean("vocab_only") : false,
// String lora,
params.hasKey("lora") ? params.getString("lora") : "",
// float lora_scaled,
@@ -181,6 +185,10 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f,
// float min_p,
params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f,
// float xtc_threshold,
params.hasKey("xtc_threshold") ? (float) params.getDouble("xtc_threshold") : 0.00f,
// float xtc_probability,
params.hasKey("xtc_probability") ? (float) params.getDouble("xtc_probability") : 0.00f,
// float tfs_z,
params.hasKey("tfs_z") ? (float) params.getDouble("tfs_z") : 1.00f,
// float typical_p,
@@ -248,16 +256,34 @@ public void release() {

static {
Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
if (LlamaContext.isArm64V8a()) {
String cpuFeatures = LlamaContext.getCpuFeatures();
Log.d(NAME, "CPU features: " + cpuFeatures);

boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");

if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
String cpuFeatures = LlamaContext.getCpuFeatures();
Log.d(NAME, "CPU features: " + cpuFeatures);
boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
boolean hasSve = cpuFeatures.contains("sve");
boolean hasI8mm = cpuFeatures.contains("i8mm");
boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
Log.d(NAME, "- hasFp16: " + hasFp16);
Log.d(NAME, "- hasDotProd: " + hasDotProd);
Log.d(NAME, "- hasSve: " + hasSve);
Log.d(NAME, "- hasI8mm: " + hasI8mm);
Log.d(NAME, "- isAtLeastArmV82: " + isAtLeastArmV82);
Log.d(NAME, "- isAtLeastArmV84: " + isAtLeastArmV84);

// TODO: Add runtime check for cpu features
if (LlamaContext.isArm64V8a()) {
if (isAtLeastArmV84 && hasSve && hasI8mm && hasFp16 && hasDotProd) {
Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm_sve.so");
System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm_sve");
} else if (isAtLeastArmV84 && hasSve && hasFp16 && hasDotProd) {
Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_sve.so");
System.loadLibrary("rnllama_v8_4_fp16_dotprod_sve");
} else if (isAtLeastArmV84 && hasI8mm && hasFp16 && hasDotProd) {
Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm.so");
System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm");
} else if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod.so");
System.loadLibrary("rnllama_v8_4_fp16_dotprod");
} else if (isAtLeastArmV82 && hasFp16 && hasDotProd) {
@@ -270,14 +296,16 @@ public void release() {
Log.d(NAME, "Loading librnllama_v8.so");
System.loadLibrary("rnllama_v8");
}
// Log.d(NAME, "Loading librnllama_v8_7.so with runtime feature detection");
// System.loadLibrary("rnllama_v8_7");
} else if (LlamaContext.isX86_64()) {
Log.d(NAME, "Loading librnllama_x86_64.so");
System.loadLibrary("rnllama_x86_64");
Log.d(NAME, "Loading librnllama_x86_64.so");
System.loadLibrary("rnllama_x86_64");
} else {
Log.d(NAME, "Loading default librnllama.so");
System.loadLibrary("rnllama");
Log.d(NAME, "Loading default librnllama.so");
System.loadLibrary("rnllama");
}
}
}

private static boolean isArm64V8a() {
return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
@@ -316,6 +344,7 @@ protected static native long initContext(
int n_gpu_layers, // TODO: Support this
boolean use_mlock,
boolean use_mmap,
boolean vocab_only,
String lora,
float lora_scaled,
float rope_freq_base,
@@ -357,6 +386,8 @@ protected static native WritableMap doCompletion(
int top_k,
float top_p,
float min_p,
float xtc_threshold,
float xtc_probability,
float tfs_z,
float typical_p,
int seed,
@@ -373,4 +404,5 @@ protected static native WritableMap doCompletion(
protected static native WritableMap embedding(long contextPtr, String text);
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
protected static native void freeContext(long contextPtr);
protected static native void logToAndroid();
}
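A hedged usage sketch of the two options this commit threads through `LlamaContext`: `vocab_only` at init time (tokenizer-only load; warmup is skipped in jni.cpp) and the XTC sampler parameters at completion time. Only the map keys are taken from the diff above; the model path, prompt value, and the `"prompt"` key itself are assumptions for illustration.

```java
import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.WritableMap;
import com.rnllama.LlamaContext;

// Sketch: exercise vocab_only and the new XTC sampling knobs.
public final class UsageSketch {
    public static void run(ReactApplicationContext reactContext) {
        // Tokenizer-only context: vocab_only=true also disables warmup in jni.cpp.
        WritableMap vocabOnlyParams = Arguments.createMap();
        vocabOnlyParams.putString("model", "/data/local/tmp/model.gguf"); // placeholder path
        vocabOnlyParams.putBoolean("vocab_only", true);
        LlamaContext tokenizerCtx = new LlamaContext(1, reactContext, vocabOnlyParams);

        // Full context with XTC sampling enabled for completion.
        WritableMap initParams = Arguments.createMap();
        initParams.putString("model", "/data/local/tmp/model.gguf"); // placeholder path
        LlamaContext ctx = new LlamaContext(2, reactContext, initParams);

        WritableMap completionParams = Arguments.createMap();
        completionParams.putString("prompt", "Hello"); // "prompt" key assumed; not shown in this hunk
        completionParams.putDouble("xtc_threshold", 0.1);   // 0.0 = XTC disabled (default)
        completionParams.putDouble("xtc_probability", 0.5); // 0.0 = XTC disabled (default)
        WritableMap result = ctx.completion(completionParams);

        ctx.release();
        tokenizerCtx.release();
    }
}
```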
82 changes: 53 additions & 29 deletions android/src/main/jni.cpp
@@ -3,6 +3,7 @@
// #include <android/asset_manager_jni.h>
#include <android/log.h>
#include <cstdlib>
#include <ctime>
#include <sys/sysinfo.h>
#include <string>
#include <thread>
@@ -21,6 +22,13 @@ static inline int min(int a, int b) {
return (a < b) ? a : b;
}

static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
}

extern "C" {

// Method to create WritableMap
@@ -139,14 +147,20 @@ Java_com_rnllama_LlamaContext_initContext(
jint n_gpu_layers, // TODO: Support this
jboolean use_mlock,
jboolean use_mmap,
jboolean vocab_only,
jstring lora_str,
jfloat lora_scaled,
jfloat rope_freq_base,
jfloat rope_freq_scale
) {
UNUSED(thiz);

gpt_params defaultParams;
common_params defaultParams;

defaultParams.vocab_only = vocab_only;
if(vocab_only) {
defaultParams.warmup = false;
}

const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
defaultParams.model = model_path_chars;
@@ -159,7 +173,7 @@ Java_com_rnllama_LlamaContext_initContext(
int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
defaultParams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

defaultParams.n_gpu_layers = n_gpu_layers;

Expand Down Expand Up @@ -235,7 +249,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];

std::vector<llama_chat_msg> chat;
std::vector<common_chat_msg> chat;

int messages_len = env->GetArrayLength(messages);
for (int i = 0; i < messages_len; i++) {
@@ -259,7 +273,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
}

const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
std::string formatted_chat = llama_chat_apply_template(llama->model, tmpl_chars, chat, true);
std::string formatted_chat = common_chat_apply_template(llama->model, tmpl_chars, chat, true);

return env->NewStringUTF(formatted_chat.c_str());
}
@@ -364,7 +378,8 @@ Java_com_rnllama_LlamaContext_doCompletion(
jint top_k,
jfloat top_p,
jfloat min_p,
jfloat tfs_z,
jfloat xtc_threshold,
jfloat xtc_probability,
jfloat typical_p,
jint seed,
jobjectArray stop,
@@ -377,18 +392,18 @@ Java_com_rnllama_LlamaContext_doCompletion(

llama->rewind();

llama_reset_timings(llama->ctx);
//llama_reset_timings(llama->ctx);

llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
llama->params.seed = seed;
llama->params.sparams.seed = (seed == -1) ? time(NULL) : seed;

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
llama->params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

llama->params.n_predict = n_predict;
llama->params.ignore_eos = ignore_eos;
llama->params.sparams.ignore_eos = ignore_eos;

auto & sparams = llama->params.sparams;
sparams.temp = temperature;
@@ -403,14 +418,15 @@ Java_com_rnllama_LlamaContext_doCompletion(
sparams.top_k = top_k;
sparams.top_p = top_p;
sparams.min_p = min_p;
sparams.tfs_z = tfs_z;
sparams.typical_p = typical_p;
sparams.typ_p = typical_p;
sparams.n_probs = n_probs;
sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
sparams.xtc_threshold = xtc_threshold;
sparams.xtc_probability = xtc_probability;

sparams.logit_bias.clear();
if (ignore_eos) {
sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
}

const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
@@ -424,9 +440,9 @@ Java_com_rnllama_LlamaContext_doCompletion(
llama_token tok = static_cast<llama_token>(doubleArray[0]);
if (tok >= 0 && tok < n_vocab) {
if (doubleArray[1] != 0) { // If the second element is not false (0)
sparams.logit_bias[tok] = doubleArray[1];
sparams.logit_bias[tok].bias = doubleArray[1];
} else {
sparams.logit_bias[tok] = -INFINITY;
sparams.logit_bias[tok].bias = -INFINITY;
}
}

@@ -460,7 +476,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
if (token_with_probs.tok == -1 || llama->incomplete) {
continue;
}
const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);

size_t pos = std::min(sent_count, llama->generated_text.size());

@@ -495,7 +511,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
putString(env, tokenResult, "token", to_send.c_str());

if (llama->params.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
@@ -512,7 +528,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
}
}

llama_print_timings(llama->ctx);
llama_perf_context_print(llama->ctx);
llama->is_predicting = false;

auto result = createWriteableMap(env);
@@ -527,16 +543,17 @@ Java_com_rnllama_LlamaContext_doCompletion(
putString(env, result, "stopping_word", llama->stopping_word.c_str());
putInt(env, result, "tokens_cached", llama->n_past);

const auto timings = llama_get_timings(llama->ctx);
const auto timings_token = llama_perf_context(llama -> ctx);

auto timingsResult = createWriteableMap(env);
putInt(env, timingsResult, "prompt_n", timings.n_p_eval);
putInt(env, timingsResult, "prompt_ms", timings.t_p_eval_ms);
putInt(env, timingsResult, "prompt_per_token_ms", timings.t_p_eval_ms / timings.n_p_eval);
putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
putInt(env, timingsResult, "predicted_n", timings.n_eval);
putInt(env, timingsResult, "predicted_ms", timings.t_eval_ms);
putInt(env, timingsResult, "predicted_per_token_ms", timings.t_eval_ms / timings.n_eval);
putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings.t_eval_ms * timings.n_eval);
putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
putInt(env, timingsResult, "prompt_ms", timings_token.t_p_eval_ms);
putInt(env, timingsResult, "prompt_per_token_ms", timings_token.t_p_eval_ms / timings_token.n_p_eval);
putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings_token.t_p_eval_ms * timings_token.n_p_eval);
putInt(env, timingsResult, "predicted_n", timings_token.n_eval);
putInt(env, timingsResult, "predicted_ms", timings_token.t_eval_ms);
putInt(env, timingsResult, "predicted_per_token_ms", timings_token.t_eval_ms / timings_token.n_eval);
putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings_token.t_eval_ms * timings_token.n_eval);

putMap(env, result, "timings", timingsResult);

@@ -569,7 +586,7 @@ Java_com_rnllama_LlamaContext_tokenize(

const char *text_chars = env->GetStringUTFChars(text, nullptr);

const std::vector<llama_token> toks = llama_tokenize(
const std::vector<llama_token> toks = common_tokenize(
llama->ctx,
text_chars,
false
@@ -623,7 +640,7 @@ Java_com_rnllama_LlamaContext_embedding(

llama->rewind();

llama_reset_timings(llama->ctx);
llama_perf_context_reset(llama->ctx);

llama->params.prompt = text_chars;

@@ -681,9 +698,16 @@ Java_com_rnllama_LlamaContext_freeContext(
}
if (llama->ctx_sampling != nullptr)
{
llama_sampling_free(llama->ctx_sampling);
common_sampler_free(llama->ctx_sampling);
}
context_map.erase((long) llama->ctx);
}

JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
UNUSED(env);
UNUSED(thiz);
llama_log_set(log_callback, NULL);
}

} // extern "C"