Skip to content

Commit

Permalink
feat(android, cpp): impl stopCompletion & tokenize
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Aug 26, 2023
1 parent 55ec4c8 commit 3045de4
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 19 deletions.
27 changes: 16 additions & 11 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ public class LlamaContext {
private int jobId = -1;
private DeviceEventManagerModule.RCTDeviceEventEmitter eventEmitter;

private boolean isPredicting = false;
private boolean isInterrupted = false;

public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap params) {
if (!params.hasKey("model")) {
throw new IllegalArgumentException("Missing required parameter: model");
Expand Down Expand Up @@ -90,16 +87,11 @@ void onPartialCompletion(WritableMap tokenResult) {
}

public WritableMap completion(ReadableMap params) {

Log.i(NAME, "completion: " + this.context);

isPredicting = true;
isInterrupted = false;
if (!params.hasKey("prompt")) {
throw new IllegalArgumentException("Missing required parameter: prompt");
}

WritableMap result = doCompletion(
return doCompletion(
this.context,
// String prompt,
params.getString("prompt"),
Expand Down Expand Up @@ -144,12 +136,22 @@ public WritableMap completion(ReadableMap params) {
// PartialCompletionCallback partial_completion_callback
new PartialCompletionCallback(this)
);
}

isPredicting = false;
// Requests cancellation of an in-flight completion by setting the native
// interrupt flag. Returns immediately; the generation loop observes the
// flag and stops on its next iteration.
public void stopCompletion() {
stopCompletion(this.context);
}

return result;
// Returns true while the native side is running a completion for this
// context (queried from the native is_predicting flag).
public boolean isPredicting() {
return isPredicting(this.context);
}

// Tokenizes `text` with the native tokenizer and returns the token ids
// wrapped in a map under the "tokens" key.
public WritableMap tokenize(String text) {
  WritableMap map = Arguments.createMap();
  map.putArray("tokens", tokenize(this.context, text));
  return map;
}

public void release() {
freeContext(context);
Expand Down Expand Up @@ -257,5 +259,8 @@ protected static native WritableMap doCompletion(
int[][] logit_bias,
PartialCompletionCallback partial_completion_callback
);
protected static native void stopCompletion(long contextPtr);
protected static native boolean isPredicting(long contextPtr);
protected static native WritableArray tokenize(long contextPtr, String text);
protected static native void freeContext(long contextPtr);
}
51 changes: 51 additions & 0 deletions android/src/main/jni/rnllama/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ static inline jobject createWritableArray(JNIEnv *env) {
return map;
}

// Pushes an int onto a React Native WritableArray via JNI.
// NOTE: this helper is invoked once per element inside loops (see tokenize),
// so the local class reference must be released here — otherwise each call
// leaks one entry in the JNI local reference table, which can overflow for
// long token lists.
static inline void pushInt(JNIEnv *env, jobject arr, int value) {
  jclass arrayClass = env->FindClass("com/facebook/react/bridge/WritableArray");
  jmethodID pushIntMethod = env->GetMethodID(arrayClass, "pushInt", "(I)V");

  env->CallVoidMethod(arr, pushIntMethod, value);
  env->DeleteLocalRef(arrayClass);
}


// Method to push WritableMap into WritableArray
static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
Expand Down Expand Up @@ -337,6 +346,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
}

llama_print_timings(llama->ctx);
llama->is_predicting = false;

auto result = createWriteableMap(env);
putString(env, result, "text", llama->generated_text.c_str());
Expand Down Expand Up @@ -366,6 +376,47 @@ Java_com_rnllama_LlamaContext_doCompletion(
return reinterpret_cast<jobject>(result);
}

// JNI entry point: sets the interrupt flag so the generation loop stops.
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_stopCompletion(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
  UNUSED(env);
  UNUSED(thiz);
  // Use find() rather than operator[]: indexing an unknown/already-freed
  // context pointer would default-insert a null entry and the flag write
  // below would dereference it.
  auto it = context_map.find((long) context_ptr);
  if (it == context_map.end() || it->second == nullptr) return;
  it->second->is_interrupted = true;
}

// JNI entry point: reports whether a completion is currently running for
// this context. Returns JNI_FALSE for an unknown context instead of
// default-inserting a null entry (operator[]) and dereferencing it.
JNIEXPORT jboolean JNICALL
Java_com_rnllama_LlamaContext_isPredicting(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
  UNUSED(env);
  UNUSED(thiz);
  auto it = context_map.find((long) context_ptr);
  if (it == context_map.end() || it->second == nullptr) return JNI_FALSE;
  return it->second->is_predicting ? JNI_TRUE : JNI_FALSE;
}

// JNI entry point: tokenizes the given Java string with the llama context
// and returns the token ids as a WritableArray of ints.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_tokenize(
    JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
  UNUSED(thiz);
  auto llama = context_map[(long) context_ptr];

  const char *text_chars = env->GetStringUTFChars(text, nullptr);
  if (text_chars == nullptr) {
    // GetStringUTFChars failed (out of memory); a Java exception is already
    // pending, so the return value is ignored by the caller.
    return nullptr;
  }

  const std::vector<llama_token> toks = llama_tokenize(
    llama->ctx,
    text_chars,
    false
  );

  jobject result = createWritableArray(env);
  for (const auto &tok : toks) {
    pushInt(env, result, tok);
  }

  env->ReleaseStringUTFChars(text, text_chars);
  return result;
}

JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_freeContext(
JNIEnv *env, jobject thiz, jlong context_ptr) {
Expand Down
68 changes: 61 additions & 7 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,13 @@ protected WritableMap doInBackground(Void... voids) {
try {
LlamaContext context = contexts.get(contextId);
if (context == null) {
throw new Exception("Context " + id + " not found");
throw new Exception("Context not found");
}
if (context.isPredicting()) {
throw new Exception("Context is busy");
}
WritableMap result = context.completion(params);
promise.resolve(result);
return result;
} catch (Exception e) {
exception = e;
}
Expand All @@ -118,14 +121,65 @@ protected void onPostExecute(WritableMap result) {
}

@ReactMethod
public void stopCompletion(double contextId, final Promise promise) {
// TODO: implement
}
// Stops an in-flight completion for the given context id.
// NOTE: this must NOT be queued with AsyncTask.execute(): that uses the
// shared SERIAL_EXECUTOR — the same serial queue the completion task runs
// on — so the stop request would only run after the completion it is meant
// to interrupt has already finished. stopCompletion() merely sets the
// native interrupt flag, so it is safe to run inline on the calling thread.
public void stopCompletion(double id, final Promise promise) {
  final int contextId = (int) id;
  try {
    LlamaContext context = contexts.get(contextId);
    if (context == null) {
      throw new Exception("Context not found");
    }
    context.stopCompletion();
    promise.resolve(null);
  } catch (Exception e) {
    promise.reject(e);
  }
}

@ReactMethod
public void tokenize(double contextId, final String text, final Promise promise) {
// TODO: implement
// Tokenizes `text` off the UI thread and resolves the promise with a map
// of token ids, or rejects it when the context cannot be found.
public void tokenize(double id, final String text, final Promise promise) {
  final int contextId = (int) id;
  new AsyncTask<Void, Void, WritableMap>() {
    private Exception error;

    @Override
    protected WritableMap doInBackground(Void... unused) {
      LlamaContext context = contexts.get(contextId);
      if (context == null) {
        error = new Exception("Context not found");
        return null;
      }
      try {
        return context.tokenize(text);
      } catch (Exception e) {
        error = e;
        return null;
      }
    }

    @Override
    protected void onPostExecute(WritableMap tokens) {
      if (error == null) {
        promise.resolve(tokens);
      } else {
        promise.reject(error);
      }
    }
  }.execute();
}

@ReactMethod
Expand Down
1 change: 0 additions & 1 deletion ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ - (NSArray *)tokenize:(NSString *)text {
const std::vector<llama_token> toks = llama_tokenize(llama->ctx, [text UTF8String], false);
NSMutableArray *result = [[NSMutableArray alloc] init];
for (llama_token tok : toks) {
printf("%d\n", tok);
[result addObject:@(tok)];
}
return result;
Expand Down

0 comments on commit 3045de4

Please sign in to comment.