diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java
index dad6fb5..69281b3 100644
--- a/android/src/main/java/com/rnllama/LlamaContext.java
+++ b/android/src/main/java/com/rnllama/LlamaContext.java
@@ -139,7 +139,7 @@ public WritableMap completion(ReadableMap params) {
       }
     }
 
-    return doCompletion(
+    WritableMap result = doCompletion(
       this.context,
       // String prompt,
       params.getString("prompt"),
@@ -193,6 +193,10 @@ public WritableMap completion(ReadableMap params) {
         params.hasKey("emit_partial_completion") ? params.getBoolean("emit_partial_completion") : false
       )
     );
+    if (result.hasKey("error")) {
+      throw new IllegalStateException(result.getString("error"));
+    }
+    return result;
   }
 
   public void stopCompletion() {
@@ -217,12 +221,14 @@ public String detokenize(ReadableArray tokens) {
     return detokenize(this.context, toks);
   }
 
-  public WritableMap embedding(String text) {
+  public WritableMap getEmbedding(String text) {
     if (isEmbeddingEnabled(this.context) == false) {
       throw new IllegalStateException("Embedding is not enabled");
     }
-    WritableMap result = Arguments.createMap();
-    result.putArray("embedding", embedding(this.context, text));
+    WritableMap result = embedding(this.context, text);
+    if (result.hasKey("error")) {
+      throw new IllegalStateException(result.getString("error"));
+    }
     return result;
   }
 
@@ -354,7 +360,7 @@ protected static native WritableMap doCompletion(
   protected static native WritableArray tokenize(long contextPtr, String text);
   protected static native String detokenize(long contextPtr, int[] tokens);
   protected static native boolean isEmbeddingEnabled(long contextPtr);
-  protected static native WritableArray embedding(long contextPtr, String text);
+  protected static native WritableMap embedding(long contextPtr, String text);
   protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
   protected static native void freeContext(long contextPtr);
 }
diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java
index 0b673d9..430eae7 100644
--- a/android/src/main/java/com/rnllama/RNLlama.java
+++ b/android/src/main/java/com/rnllama/RNLlama.java
@@ -297,7 +297,7 @@ protected WritableMap doInBackground(Void... voids) {
           if (context == null) {
             throw new Exception("Context not found");
           }
-          return context.embedding(text);
+          return context.getEmbedding(text);
         } catch (Exception e) {
           exception = e;
         }
diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp
index 558eeb6..3de36f3 100644
--- a/android/src/main/jni.cpp
+++ b/android/src/main/jni.cpp
@@ -581,17 +581,24 @@ Java_com_rnllama_LlamaContext_embedding(
     llama->params.prompt = text_chars;
 
     llama->params.n_predict = 0;
+
+    auto result = createWriteableMap(env);
+    if (!llama->initSampling()) {
+        putString(env, result, "error", "Failed to initialize sampling");
+        return reinterpret_cast<jobject>(result);
+    }
+
     llama->beginCompletion();
     llama->loadPrompt();
     llama->doCompletion();
 
     std::vector<float> embedding = llama->getEmbedding();
 
-    jobject result = createWritableArray(env);
-
+    auto embeddings = createWritableArray(env);
     for (const auto &val : embedding) {
-        pushDouble(env, result, (double) val);
+        pushDouble(env, embeddings, (double) val);
     }
+    putArray(env, result, "embedding", embeddings);
 
     env->ReleaseStringUTFChars(text, text_chars);
     return result;
diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm
index 58819f2..cee2367 100644
--- a/ios/RNLlamaContext.mm
+++ b/ios/RNLlamaContext.mm
@@ -365,8 +365,12 @@ - (NSArray *)embedding:(NSString *)text {
     llama->params.prompt = [text UTF8String];
 
     llama->params.n_predict = 0;
-    llama->loadPrompt();
+
+    if (!llama->initSampling()) {
+        @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to initialize sampling" userInfo:nil];
+    }
     llama->beginCompletion();
+    llama->loadPrompt();
     llama->doCompletion();
 
     std::vector<float> result = llama->getEmbedding();