feat: sync llama.cpp #79

Merged: 7 commits, Nov 2, 2024
22 changes: 17 additions & 5 deletions android/src/main/CMakeLists.txt
@@ -9,21 +9,23 @@ include_directories(${RNLLAMA_LIB_DIR})

set(
SOURCE_FILES
${RNLLAMA_LIB_DIR}/llama-grammar.cpp
${RNLLAMA_LIB_DIR}/llama-sampling.cpp
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
${RNLLAMA_LIB_DIR}/log.cpp

${RNLLAMA_LIB_DIR}/ggml-aarch64.c
${RNLLAMA_LIB_DIR}/ggml-alloc.c
${RNLLAMA_LIB_DIR}/ggml-backend.c
${RNLLAMA_LIB_DIR}/ggml-backend.cpp
${RNLLAMA_LIB_DIR}/ggml.c
${RNLLAMA_LIB_DIR}/ggml-quants.c
${RNLLAMA_LIB_DIR}/common.cpp
${RNLLAMA_LIB_DIR}/grammar-parser.cpp
${RNLLAMA_LIB_DIR}/json.hpp
${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
${RNLLAMA_LIB_DIR}/sampling.cpp
${RNLLAMA_LIB_DIR}/unicode-data.cpp
${RNLLAMA_LIB_DIR}/unicode.cpp
${RNLLAMA_LIB_DIR}/llama.cpp
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
${RNLLAMA_LIB_DIR}/llama-sampling.cpp
${RNLLAMA_LIB_DIR}/llama-grammar.cpp
${RNLLAMA_LIB_DIR}/sgemm.cpp
${RNLLAMA_LIB_DIR}/ggml-aarch64.c
${RNLLAMA_LIB_DIR}/rn-llama.hpp
@@ -65,10 +67,20 @@ build_library("rnllama" "")

if (${ANDROID_ABI} STREQUAL "arm64-v8a")
# ARM64 targets
build_library("rnllama_v8_4_fp16_dotprod_sve" "-march=armv8.4-a+fp16+dotprod+sve")
build_library("rnllama_v8_4_fp16_dotprod_i8mm_sve" "-march=armv8.4-a+fp16+dotprod+i8mm+sve")
build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm")
build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod")
build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod")
build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16")
build_library("rnllama_v8" "-march=armv8-a")

# https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk
# llama.cpp will deal with the cpu features
# build_library("rnllama_v8_7" "-march=armv8.7-a")
# TODO: Add support runtime check for cpu features
# At the moment runtime check is failing.

elseif (${ANDROID_ABI} STREQUAL "x86_64")
# x86_64 target
build_library("rnllama_x86_64" "-march=x86-64" "-mtune=intel" "-msse4.2" "-mpopcnt")
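Since the single armv8.7 build with runtime feature dispatch is still commented out (see the TODO above), this PR instead ships one library per feature combination and picks among them in Java at load time. As a rough illustration of how the -march variants above map to feature flags (a simplified sketch only, not the shipped loader; the actual selection logic, with logging and the x86_64/default fallbacks, is in LlamaContext.java below):

```java
final class VariantPicker {
  // Simplified illustration of the variant naming scheme used in the CMake
  // targets above; assumes the CPU feature flags have already been parsed.
  static String pickArm64Variant(boolean v84, boolean v82, boolean fp16,
                                 boolean dotprod, boolean i8mm, boolean sve) {
    if (v84 && fp16 && dotprod && sve && i8mm) return "rnllama_v8_4_fp16_dotprod_i8mm_sve";
    if (v84 && fp16 && dotprod && sve)         return "rnllama_v8_4_fp16_dotprod_sve";
    if (v84 && fp16 && dotprod && i8mm)        return "rnllama_v8_4_fp16_dotprod_i8mm";
    if (v84 && fp16 && dotprod)                return "rnllama_v8_4_fp16_dotprod";
    if (v82 && fp16 && dotprod)                return "rnllama_v8_2_fp16_dotprod";
    if (v82 && fp16)                           return "rnllama_v8_2_fp16";
    return "rnllama_v8"; // baseline armv8-a build
  }
}
```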
60 changes: 46 additions & 14 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -35,6 +35,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
if (!params.hasKey("model")) {
throw new IllegalArgumentException("Missing required parameter: model");
}
Log.d(NAME, "Setting log callback");
logToAndroid();
this.id = id;
this.context = initContext(
// String model,
@@ -53,6 +55,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
// boolean use_mmap,
params.hasKey("use_mmap") ? params.getBoolean("use_mmap") : true,
//boolean vocab_only,
params.hasKey("vocab_only") ? params.getBoolean("vocab_only") : false,
// String lora,
params.hasKey("lora") ? params.getString("lora") : "",
// float lora_scaled,
@@ -181,6 +185,10 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f,
// float min_p,
params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f,
// float xtc_threshold,
params.hasKey("xtc_threshold") ? (float) params.getDouble("xtc_threshold") : 0.00f,
// float xtc_probability,
params.hasKey("xtc_probability") ? (float) params.getDouble("xtc_probability") : 0.00f,
// float tfs_z,
params.hasKey("tfs_z") ? (float) params.getDouble("tfs_z") : 1.00f,
// float typical_p,
@@ -248,16 +256,34 @@ public void release() {

static {
Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
if (LlamaContext.isArm64V8a()) {
String cpuFeatures = LlamaContext.getCpuFeatures();
Log.d(NAME, "CPU features: " + cpuFeatures);

boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");

if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
String cpuFeatures = LlamaContext.getCpuFeatures();
Log.d(NAME, "CPU features: " + cpuFeatures);
boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
boolean hasSve = cpuFeatures.contains("sve");
boolean hasI8mm = cpuFeatures.contains("i8mm");
boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
Log.d(NAME, "- hasFp16: " + hasFp16);
Log.d(NAME, "- hasDotProd: " + hasDotProd);
Log.d(NAME, "- hasSve: " + hasSve);
Log.d(NAME, "- hasI8mm: " + hasI8mm);
Log.d(NAME, "- isAtLeastArmV82: " + isAtLeastArmV82);
Log.d(NAME, "- isAtLeastArmV84: " + isAtLeastArmV84);

// TODO: Add runtime check for cpu features
if (LlamaContext.isArm64V8a()) {
if (isAtLeastArmV84 && hasSve && hasI8mm && hasFp16 && hasDotProd) {
Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm_sve.so");
System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm_sve");
} else if (isAtLeastArmV84 && hasSve && hasFp16 && hasDotProd) {
Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_sve.so");
System.loadLibrary("rnllama_v8_4_fp16_dotprod_sve");
} else if (isAtLeastArmV84 && hasI8mm && hasFp16 && hasDotProd) {
Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm.so");
System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm");
} else if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod.so");
System.loadLibrary("rnllama_v8_4_fp16_dotprod");
} else if (isAtLeastArmV82 && hasFp16 && hasDotProd) {
@@ -270,14 +296,16 @@ public void release() {
Log.d(NAME, "Loading librnllama_v8.so");
System.loadLibrary("rnllama_v8");
}
// Log.d(NAME, "Loading librnllama_v8_7.so with runtime feature detection");
// System.loadLibrary("rnllama_v8_7");
} else if (LlamaContext.isX86_64()) {
Log.d(NAME, "Loading librnllama_x86_64.so");
System.loadLibrary("rnllama_x86_64");
Log.d(NAME, "Loading librnllama_x86_64.so");
System.loadLibrary("rnllama_x86_64");
} else {
Log.d(NAME, "Loading default librnllama.so");
System.loadLibrary("rnllama");
Log.d(NAME, "Loading default librnllama.so");
System.loadLibrary("rnllama");
}
}
}

private static boolean isArm64V8a() {
return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
@@ -316,6 +344,7 @@ protected static native long initContext(
int n_gpu_layers, // TODO: Support this
boolean use_mlock,
boolean use_mmap,
boolean vocab_only,
String lora,
float lora_scaled,
float rope_freq_base,
@@ -357,6 +386,8 @@ protected static native WritableMap doCompletion(
int top_k,
float top_p,
float min_p,
float xtc_threshold,
float xtc_probability,
float tfs_z,
float typical_p,
int seed,
@@ -373,4 +404,5 @@ protected static native WritableMap doCompletion(
protected static native WritableMap embedding(long contextPtr, String text);
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
protected static native void freeContext(long contextPtr);
protected static native void logToAndroid();
}
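A minimal caller-side sketch of the new options (not part of this PR): the map keys mirror the params.hasKey(...) lookups above, while the model path and the "prompt" key are placeholders/assumptions.

```java
import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.WritableMap;
import com.rnllama.LlamaContext;

public class SyncExample {
  static WritableMap run(ReactApplicationContext reactContext) {
    // New in this PR: vocab_only loads just the vocabulary (warmup is skipped
    // on the native side), enough for tokenization-style use.
    WritableMap vocabParams = Arguments.createMap();
    vocabParams.putString("model", "/data/local/tmp/model.gguf"); // placeholder path
    vocabParams.putBoolean("vocab_only", true);
    LlamaContext vocabOnlyCtx = new LlamaContext(1, reactContext, vocabParams);

    // New in this PR: XTC sampler parameters; both default to 0.0 above,
    // which leaves XTC disabled.
    WritableMap fullParams = Arguments.createMap();
    fullParams.putString("model", "/data/local/tmp/model.gguf");
    LlamaContext ctx = new LlamaContext(2, reactContext, fullParams);

    WritableMap completionParams = Arguments.createMap();
    completionParams.putString("prompt", "Hello"); // "prompt" key assumed, handled elsewhere in completion()
    completionParams.putDouble("xtc_threshold", 0.1);
    completionParams.putDouble("xtc_probability", 0.5);
    return ctx.completion(completionParams);
  }
}
```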
82 changes: 53 additions & 29 deletions android/src/main/jni.cpp
@@ -3,6 +3,7 @@
// #include <android/asset_manager_jni.h>
#include <android/log.h>
#include <cstdlib>
#include <ctime>
#include <sys/sysinfo.h>
#include <string>
#include <thread>
@@ -21,6 +22,13 @@ static inline int min(int a, int b) {
return (a < b) ? a : b;
}

static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
}

extern "C" {

// Method to create WritableMap
@@ -139,14 +147,20 @@ Java_com_rnllama_LlamaContext_initContext(
jint n_gpu_layers, // TODO: Support this
jboolean use_mlock,
jboolean use_mmap,
jboolean vocab_only,
jstring lora_str,
jfloat lora_scaled,
jfloat rope_freq_base,
jfloat rope_freq_scale
) {
UNUSED(thiz);

gpt_params defaultParams;
common_params defaultParams;

defaultParams.vocab_only = vocab_only;
if(vocab_only) {
defaultParams.warmup = false;
}

const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
defaultParams.model = model_path_chars;
@@ -159,7 +173,7 @@
int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
defaultParams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

defaultParams.n_gpu_layers = n_gpu_layers;

@@ -235,7 +249,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];

std::vector<llama_chat_msg> chat;
std::vector<common_chat_msg> chat;

int messages_len = env->GetArrayLength(messages);
for (int i = 0; i < messages_len; i++) {
@@ -259,7 +273,7 @@
}

const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
std::string formatted_chat = llama_chat_apply_template(llama->model, tmpl_chars, chat, true);
std::string formatted_chat = common_chat_apply_template(llama->model, tmpl_chars, chat, true);

return env->NewStringUTF(formatted_chat.c_str());
}
@@ -364,7 +378,8 @@ Java_com_rnllama_LlamaContext_doCompletion(
jint top_k,
jfloat top_p,
jfloat min_p,
jfloat tfs_z,
jfloat xtc_threshold,
jfloat xtc_probability,
jfloat typical_p,
jint seed,
jobjectArray stop,
@@ -377,18 +392,18 @@

llama->rewind();

llama_reset_timings(llama->ctx);
//llama_reset_timings(llama->ctx);

llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
llama->params.seed = seed;
llama->params.sparams.seed = (seed == -1) ? time(NULL) : seed;

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
llama->params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

llama->params.n_predict = n_predict;
llama->params.ignore_eos = ignore_eos;
llama->params.sparams.ignore_eos = ignore_eos;

auto & sparams = llama->params.sparams;
sparams.temp = temperature;
@@ -403,14 +418,15 @@
sparams.top_k = top_k;
sparams.top_p = top_p;
sparams.min_p = min_p;
sparams.tfs_z = tfs_z;
sparams.typical_p = typical_p;
sparams.typ_p = typical_p;
sparams.n_probs = n_probs;
sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
sparams.xtc_threshold = xtc_threshold;
sparams.xtc_probability = xtc_probability;

sparams.logit_bias.clear();
if (ignore_eos) {
sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
}

const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
@@ -424,9 +440,9 @@
llama_token tok = static_cast<llama_token>(doubleArray[0]);
if (tok >= 0 && tok < n_vocab) {
if (doubleArray[1] != 0) { // If the second element is not false (0)
sparams.logit_bias[tok] = doubleArray[1];
sparams.logit_bias[tok].bias = doubleArray[1];
} else {
sparams.logit_bias[tok] = -INFINITY;
sparams.logit_bias[tok].bias = -INFINITY;
}
}

@@ -460,7 +476,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
if (token_with_probs.tok == -1 || llama->incomplete) {
continue;
}
const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);

size_t pos = std::min(sent_count, llama->generated_text.size());

@@ -495,7 +511,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
putString(env, tokenResult, "token", to_send.c_str());

if (llama->params.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
@@ -512,7 +528,7 @@
}
}

llama_print_timings(llama->ctx);
llama_perf_context_print(llama->ctx);
llama->is_predicting = false;

auto result = createWriteableMap(env);
@@ -527,16 +543,17 @@
putString(env, result, "stopping_word", llama->stopping_word.c_str());
putInt(env, result, "tokens_cached", llama->n_past);

const auto timings = llama_get_timings(llama->ctx);
const auto timings_token = llama_perf_context(llama -> ctx);

auto timingsResult = createWriteableMap(env);
putInt(env, timingsResult, "prompt_n", timings.n_p_eval);
putInt(env, timingsResult, "prompt_ms", timings.t_p_eval_ms);
putInt(env, timingsResult, "prompt_per_token_ms", timings.t_p_eval_ms / timings.n_p_eval);
putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
putInt(env, timingsResult, "predicted_n", timings.n_eval);
putInt(env, timingsResult, "predicted_ms", timings.t_eval_ms);
putInt(env, timingsResult, "predicted_per_token_ms", timings.t_eval_ms / timings.n_eval);
putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings.t_eval_ms * timings.n_eval);
putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
putInt(env, timingsResult, "prompt_ms", timings_token.t_p_eval_ms);
putInt(env, timingsResult, "prompt_per_token_ms", timings_token.t_p_eval_ms / timings_token.n_p_eval);
putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings_token.t_p_eval_ms * timings_token.n_p_eval);
putInt(env, timingsResult, "predicted_n", timings_token.n_eval);
putInt(env, timingsResult, "predicted_ms", timings_token.t_eval_ms);
putInt(env, timingsResult, "predicted_per_token_ms", timings_token.t_eval_ms / timings_token.n_eval);
putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings_token.t_eval_ms * timings_token.n_eval);

putMap(env, result, "timings", timingsResult);

@@ -569,7 +586,7 @@ Java_com_rnllama_LlamaContext_tokenize(

const char *text_chars = env->GetStringUTFChars(text, nullptr);

const std::vector<llama_token> toks = llama_tokenize(
const std::vector<llama_token> toks = common_tokenize(
llama->ctx,
text_chars,
false
@@ -623,7 +640,7 @@ Java_com_rnllama_LlamaContext_embedding(

llama->rewind();

llama_reset_timings(llama->ctx);
llama_perf_context_reset(llama->ctx);

llama->params.prompt = text_chars;

@@ -681,9 +698,16 @@ Java_com_rnllama_LlamaContext_freeContext(
}
if (llama->ctx_sampling != nullptr)
{
llama_sampling_free(llama->ctx_sampling);
common_sampler_free(llama->ctx_sampling);
}
context_map.erase((long) llama->ctx);
}

JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
UNUSED(env);
UNUSED(thiz);
llama_log_set(log_callback, NULL);
}

} // extern "C"