Skip to content

Commit

Permalink
feat: add detokenize method (#12)
Browse files Browse the repository at this point in the history
* feat: add detokenize method

* feat: update jest mock
  • Loading branch information
jhen0409 authored Aug 29, 2023
1 parent 7cde772 commit 13b0a15
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 0 deletions.
9 changes: 9 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ public WritableMap tokenize(String text) {
return result;
}

/** Converts a bridge token array into text via the native detokenizer. */
public String detokenize(ReadableArray tokens) {
  final int count = tokens.size();
  int[] tokenIds = new int[count];
  int index = 0;
  while (index < count) {
    // Bridge numbers arrive as doubles; narrow each one to a token id.
    tokenIds[index] = (int) tokens.getDouble(index);
    index++;
  }
  return detokenize(this.context, tokenIds);
}

public WritableMap embedding(String text) {
if (isEmbeddingEnabled(this.context) == false) {
throw new IllegalStateException("Embedding is not enabled");
Expand Down Expand Up @@ -253,6 +261,7 @@ protected static native WritableMap doCompletion(
protected static native void stopCompletion(long contextPtr);
protected static native boolean isPredicting(long contextPtr);
protected static native WritableArray tokenize(long contextPtr, String text);
protected static native String detokenize(long contextPtr, int[] tokens);
protected static native boolean isEmbeddingEnabled(long contextPtr);
protected static native WritableArray embedding(long contextPtr, String text);
protected static native void freeContext(long contextPtr);
Expand Down
20 changes: 20 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,26 @@ Java_com_rnllama_LlamaContext_tokenize(
return result;
}

// Converts a Java int[] of token ids back into the UTF-8 text they encode.
// Looks up the llama context by pointer, copies the ids into a
// std::vector<llama_token>, and joins them with rnllama::tokens_to_str.
JNIEXPORT jstring JNICALL
Java_com_rnllama_LlamaContext_detokenize(
    JNIEnv *env, jobject thiz, jlong context_ptr, jintArray tokens) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    jsize tokens_len = env->GetArrayLength(tokens);
    jint *tokens_ptr = env->GetIntArrayElements(tokens, nullptr);
    std::vector<llama_token> toks;
    toks.reserve(tokens_len); // avoid repeated reallocation for known size
    for (jsize i = 0; i < tokens_len; i++) {
        toks.push_back(tokens_ptr[i]);
    }

    auto text = rnllama::tokens_to_str(llama->ctx, toks.cbegin(), toks.cend());

    // JNI_ABORT: the elements were only read, so skip the needless
    // copy-back into the Java array that mode 0 would perform.
    env->ReleaseIntArrayElements(tokens, tokens_ptr, JNI_ABORT);

    return env->NewStringUTF(text.c_str());
}

JNIEXPORT jboolean JNICALL
Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
JNIEnv *env, jobject thiz, jlong context_ptr) {
Expand Down
32 changes: 32 additions & 0 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.facebook.react.bridge.ReactMethod;
import com.facebook.react.bridge.LifecycleEventListener;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.bridge.Arguments;
import com.facebook.react.module.annotations.ReactModule;
Expand Down Expand Up @@ -182,6 +183,37 @@ protected void onPostExecute(WritableMap result) {
}.execute();
}

/** Resolves the promise with the text decoded from {@code tokens}, off the UI thread. */
@ReactMethod
public void detokenize(double id, final ReadableArray tokens, final Promise promise) {
  final int contextId = (int) id;
  new AsyncTask<Void, Void, String>() {
    private Exception caughtError;

    @Override
    protected String doInBackground(Void... voids) {
      try {
        LlamaContext context = contexts.get(contextId);
        if (context == null) {
          throw new Exception("Context not found");
        }
        return context.detokenize(tokens);
      } catch (Exception e) {
        caughtError = e;
        return null;
      }
    }

    @Override
    protected void onPostExecute(String result) {
      if (caughtError == null) {
        promise.resolve(result);
      } else {
        promise.reject(caughtError);
      }
    }
  }.execute();
}

@ReactMethod
public void embedding(double id, final String text, final Promise promise) {
final int contextId = (int) id;
Expand Down
32 changes: 32 additions & 0 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.facebook.react.bridge.ReactMethod;
import com.facebook.react.bridge.LifecycleEventListener;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.bridge.Arguments;
import com.facebook.react.module.annotations.ReactModule;
Expand Down Expand Up @@ -183,6 +184,37 @@ protected void onPostExecute(WritableMap result) {
}.execute();
}

/** Resolves the promise with the text decoded from {@code tokens}, off the UI thread. */
@ReactMethod
public void detokenize(double id, final ReadableArray tokens, final Promise promise) {
  final int contextId = (int) id;
  new AsyncTask<Void, Void, String>() {
    private Exception caughtError;

    @Override
    protected String doInBackground(Void... voids) {
      try {
        LlamaContext context = contexts.get(contextId);
        if (context == null) {
          throw new Exception("Context not found");
        }
        return context.detokenize(tokens);
      } catch (Exception e) {
        caughtError = e;
        return null;
      }
    }

    @Override
    protected void onPostExecute(String result) {
      if (caughtError == null) {
        promise.resolve(result);
      } else {
        promise.reject(caughtError);
      }
    }
  }.execute();
}

@ReactMethod
public void embedding(double id, final String text, final Promise promise) {
final int contextId = (int) id;
Expand Down
5 changes: 5 additions & 0 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ export default function App() {
// await context?.embedding(prompt).then((result) => {
// console.log('Embedding:', result)
// })

// Test detokenize
// await context?.detokenize(tokens).then((result) => {
// console.log('Detokenize:', result)
// })
}

let grammar
Expand Down
13 changes: 13 additions & 0 deletions ios/RNLlama.mm
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,19 @@ - (NSArray *)supportedEvents {
[tokens release];
}

// Resolves with the text reconstructed from the given token ids,
// or rejects when no context exists for contextId.
RCT_EXPORT_METHOD(detokenize:(double)contextId
                 tokens:(NSArray *)tokens
                 withResolver:(RCTPromiseResolveBlock)resolve
                 withRejecter:(RCTPromiseRejectBlock)reject)
{
    NSNumber *contextKey = @(contextId);
    RNLlamaContext *context = llamaContexts[contextKey];
    if (context != nil) {
        resolve([context detokenize:tokens]);
        return;
    }
    reject(@"llama_error", @"Context not found", nil);
}

RCT_EXPORT_METHOD(embedding:(double)contextId
text:(NSString *)text
withResolver:(RCTPromiseResolveBlock)resolve
Expand Down
2 changes: 2 additions & 0 deletions ios/RNLlamaContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
- (NSDictionary *)completion:(NSDictionary *)params onToken:(void (^)(NSMutableDictionary *tokenResult))onToken;
- (void)stopCompletion;
- (NSArray *)tokenize:(NSString *)text;
- (NSString *)detokenize:(NSArray *)tokens;
- (NSArray *)embedding:(NSString *)text;

- (void)invalidate;

Expand Down
9 changes: 9 additions & 0 deletions ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,15 @@ - (NSArray *)tokenize:(NSString *)text {
return result;
}

// Joins an array of NSNumber token ids back into the string they encode.
- (NSString *)detokenize:(NSArray *)tokens {
    std::vector<llama_token> tokenIds;
    tokenIds.reserve([tokens count]);
    for (NSNumber *token in tokens) {
        tokenIds.push_back([token intValue]);
    }
    const std::string joined = rnllama::tokens_to_str(llama->ctx, tokenIds.cbegin(), tokenIds.cend());
    return [NSString stringWithUTF8String:joined.c_str()];
}

- (NSArray *)embedding:(NSString *)text {
if (llama->params.embedding != true) {
@throw [NSException exceptionWithName:@"LlamaException" reason:@"Embedding is not enabled" userInfo:nil];
Expand Down
4 changes: 4 additions & 0 deletions jest/mock.js
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ if (!NativeModules.RNLlama) {

stopCompletion: jest.fn(),

tokenize: jest.fn(async () => []),
detokenize: jest.fn(async () => ''),
embedding: jest.fn(async () => []),

releaseContext: jest.fn(() => Promise.resolve()),
releaseAllContexts: jest.fn(() => Promise.resolve()),

Expand Down
1 change: 1 addition & 0 deletions src/NativeRNLlama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ export interface Spec extends TurboModule {
completion(contextId: number, params: NativeCompletionParams): Promise<NativeCompletionResult>;
stopCompletion(contextId: number): Promise<void>;
tokenize(contextId: number, text: string): Promise<NativeTokenizeResult>;
detokenize(contextId: number, tokens: number[]): Promise<string>;
embedding(contextId: number, text: string): Promise<NativeEmbeddingResult>;
releaseContext(contextId: number): Promise<void>;

Expand Down
4 changes: 4 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ export class LlamaContext {
return RNLlama.tokenize(this.id, text)
}

/**
 * Converts an array of token ids back into the text they represent.
 * Delegates to the native module's `detokenize` for this context.
 * @param tokens Token ids (e.g. as produced by `tokenize`).
 * @returns A promise resolving to the decoded string.
 */
detokenize(tokens: number[]): Promise<string> {
  return RNLlama.detokenize(this.id, tokens)
}

embedding(text: string): Promise<NativeEmbeddingResult> {
return RNLlama.embedding(this.id, text)
}
Expand Down

0 comments on commit 13b0a15

Please sign in to comment.