Skip to content

Commit

Permalink
fix: not init sampling before get embedding (#69)
Browse the repository at this point in the history
* fix: not init sampling before get embedding

* fix(android): embeddings array
  • Loading branch information
jhen0409 authored Jul 27, 2024
1 parent 7bfda3b commit 1088300
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 10 deletions.
16 changes: 11 additions & 5 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public WritableMap completion(ReadableMap params) {
}
}

return doCompletion(
WritableMap result = doCompletion(
this.context,
// String prompt,
params.getString("prompt"),
Expand Down Expand Up @@ -193,6 +193,10 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("emit_partial_completion") ? params.getBoolean("emit_partial_completion") : false
)
);
if (result.hasKey("error")) {
throw new IllegalStateException(result.getString("error"));
}
return result;
}

public void stopCompletion() {
Expand All @@ -217,12 +221,14 @@ public String detokenize(ReadableArray tokens) {
return detokenize(this.context, toks);
}

public WritableMap embedding(String text) {
public WritableMap getEmbedding(String text) {
if (isEmbeddingEnabled(this.context) == false) {
throw new IllegalStateException("Embedding is not enabled");
}
WritableMap result = Arguments.createMap();
result.putArray("embedding", embedding(this.context, text));
WritableMap result = embedding(this.context, text);
if (result.hasKey("error")) {
throw new IllegalStateException(result.getString("error"));
}
return result;
}

Expand Down Expand Up @@ -354,7 +360,7 @@ protected static native WritableMap doCompletion(
protected static native WritableArray tokenize(long contextPtr, String text);
protected static native String detokenize(long contextPtr, int[] tokens);
protected static native boolean isEmbeddingEnabled(long contextPtr);
protected static native WritableArray embedding(long contextPtr, String text);
protected static native WritableMap embedding(long contextPtr, String text);
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
protected static native void freeContext(long contextPtr);
}
2 changes: 1 addition & 1 deletion android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ protected WritableMap doInBackground(Void... voids) {
if (context == null) {
throw new Exception("Context not found");
}
return context.embedding(text);
return context.getEmbedding(text);
} catch (Exception e) {
exception = e;
}
Expand Down
13 changes: 10 additions & 3 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,17 +581,24 @@ Java_com_rnllama_LlamaContext_embedding(
llama->params.prompt = text_chars;

llama->params.n_predict = 0;

auto result = createWriteableMap(env);
if (!llama->initSampling()) {
putString(env, result, "error", "Failed to initialize sampling");
return reinterpret_cast<jobject>(result);
}

llama->beginCompletion();
llama->loadPrompt();
llama->doCompletion();

std::vector<float> embedding = llama->getEmbedding();

jobject result = createWritableArray(env);

auto embeddings = createWritableArray(env);
for (const auto &val : embedding) {
pushDouble(env, result, (double) val);
pushDouble(env, embeddings, (double) val);
}
putArray(env, result, "embedding", embeddings);

env->ReleaseStringUTFChars(text, text_chars);
return result;
Expand Down
6 changes: 5 additions & 1 deletion ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,12 @@ - (NSArray *)embedding:(NSString *)text {
llama->params.prompt = [text UTF8String];

llama->params.n_predict = 0;
llama->loadPrompt();

if (!llama->initSampling()) {
@throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to initialize sampling" userInfo:nil];
}
llama->beginCompletion();
llama->loadPrompt();
llama->doCompletion();

std::vector<float> result = llama->getEmbedding();
Expand Down

0 comments on commit 1088300

Please sign in to comment.