From 8a4b863cf154f7b5a935a68098cb7dec239bb6ed Mon Sep 17 00:00:00 2001
From: jhen
Date: Tue, 29 Aug 2023 09:47:57 +0800
Subject: [PATCH] fix: send rest of content on stop

---
 android/src/main/jni.cpp | 56 +++++++++++++++++++++++-----------------
 ios/RNLlamaContext.mm    | 46 +++++++++++++++++++--------------
 2 files changed, 59 insertions(+), 43 deletions(-)

diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp
index cfcd352..121cc11 100644
--- a/android/src/main/jni.cpp
+++ b/android/src/main/jni.cpp
@@ -317,51 +317,59 @@ Java_com_rnllama_LlamaContext_doCompletion(
     while (llama->has_next_token && !llama->is_interrupted) {
         const rnllama::completion_token_output token_with_probs = llama->doCompletion();
-        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama->ctx, token_with_probs.tok);
-        if (llama->multibyte_pending > 0) {
+        if (token_with_probs.tok == -1 || llama->multibyte_pending > 0) {
             continue;
         }
+        const std::string token_text = llama_token_to_str(llama->ctx, token_with_probs.tok);
 
         size_t pos = std::min(sent_count, llama->generated_text.size());
 
         const std::string str_test = llama->generated_text.substr(pos);
+        bool is_stop_full = false;
         size_t stop_pos =
             llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
         if (stop_pos != std::string::npos) {
+            is_stop_full = true;
             llama->generated_text.erase(
                 llama->generated_text.begin() + pos + stop_pos,
                 llama->generated_text.end());
             pos = std::min(sent_count, llama->generated_text.size());
         } else {
+            is_stop_full = false;
             stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
                 rnllama::STOP_PARTIAL);
         }
 
-        const std::string to_send = stop_pos == std::string::npos ?
-            llama->generated_text.substr(pos, std::string::npos) :
-            ""; // just don't send anything if we're not done
-        sent_count += to_send.size();
+        if (
+            stop_pos == std::string::npos ||
+            // Send rest of the text if we are at the end of the generation
+            (!llama->has_next_token && !is_stop_full && stop_pos > 0)
+        ) {
+            const std::string to_send = llama->generated_text.substr(pos, std::string::npos);
 
-        std::vector<rnllama::completion_token_output> probs_output = {};
+            sent_count += to_send.size();
 
-        auto tokenResult = createWriteableMap(env);
-        putString(env, tokenResult, "token", to_send.c_str());
-
-        if (llama->params.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
-            size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
-            size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
-            if (probs_pos < probs_stop_pos) {
-                probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
-            }
-            sent_token_probs_index = probs_stop_pos;
-
-            putArray(env, tokenResult, "completion_probabilities", tokenProbsToMap(env, llama, probs_output));
-        }
+            std::vector<rnllama::completion_token_output> probs_output = {};
+
+            auto tokenResult = createWriteableMap(env);
+            putString(env, tokenResult, "token", to_send.c_str());
+
+            if (llama->params.n_probs > 0) {
+                const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
+                size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
+                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
+                if (probs_pos < probs_stop_pos) {
+                    probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
+                }
+                sent_token_probs_index = probs_stop_pos;
 
-        jclass cb_class = env->GetObjectClass(partial_completion_callback);
-        jmethodID onPartialCompletion = env->GetMethodID(cb_class, "onPartialCompletion", "(Lcom/facebook/react/bridge/WritableMap;)V");
-        env->CallVoidMethod(partial_completion_callback, onPartialCompletion, tokenResult);
+                putArray(env, tokenResult, "completion_probabilities", tokenProbsToMap(env, llama, probs_output));
+            }
+
+            jclass cb_class = env->GetObjectClass(partial_completion_callback);
+            jmethodID onPartialCompletion = env->GetMethodID(cb_class, "onPartialCompletion", "(Lcom/facebook/react/bridge/WritableMap;)V");
+            env->CallVoidMethod(partial_completion_callback, onPartialCompletion, tokenResult);
+        }
     }
 
     llama_print_timings(llama->ctx);
diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm
index 62eeded..94ada62 100644
--- a/ios/RNLlamaContext.mm
+++ b/ios/RNLlamaContext.mm
@@ -204,49 +204,57 @@ - (NSDictionary *)completion:(NSDictionary *)params
     while (llama->has_next_token && !llama->is_interrupted) {
         const rnllama::completion_token_output token_with_probs = llama->doCompletion();
-        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama->ctx, token_with_probs.tok);
-        if (llama->multibyte_pending > 0) {
+        if (token_with_probs.tok == -1 || llama->multibyte_pending > 0) {
             continue;
         }
+        const std::string token_text = llama_token_to_str(llama->ctx, token_with_probs.tok);
 
         size_t pos = std::min(sent_count, llama->generated_text.size());
 
         const std::string str_test = llama->generated_text.substr(pos);
+        bool is_stop_full = false;
         size_t stop_pos =
             llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
         if (stop_pos != std::string::npos) {
+            is_stop_full = true;
             llama->generated_text.erase(
                 llama->generated_text.begin() + pos + stop_pos,
                 llama->generated_text.end());
             pos = std::min(sent_count, llama->generated_text.size());
         } else {
+            is_stop_full = false;
             stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
                 rnllama::STOP_PARTIAL);
        }
 
-        const std::string to_send = stop_pos == std::string::npos ?
-            llama->generated_text.substr(pos, std::string::npos) :
-            ""; // just don't send anything if we're not done
-        sent_count += to_send.size();
+        if (
+            stop_pos == std::string::npos ||
+            // Send rest of the text if we are at the end of the generation
+            (!llama->has_next_token && !is_stop_full && stop_pos > 0)
+        ) {
+            const std::string to_send = llama->generated_text.substr(pos, std::string::npos);
 
-        std::vector<rnllama::completion_token_output> probs_output = {};
+            sent_count += to_send.size();
 
-        NSMutableDictionary *tokenResult = [[NSMutableDictionary alloc] init];
-        tokenResult[@"token"] = [NSString stringWithUTF8String:to_send.c_str()];
+            std::vector<rnllama::completion_token_output> probs_output = {};
 
-        if (llama->params.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
-            size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
-            size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
-            if (probs_pos < probs_stop_pos) {
-                probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
+            NSMutableDictionary *tokenResult = [[NSMutableDictionary alloc] init];
+            tokenResult[@"token"] = [NSString stringWithUTF8String:to_send.c_str()];
+
+            if (llama->params.n_probs > 0) {
+                const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
+                size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
+                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
+                if (probs_pos < probs_stop_pos) {
+                    probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
+                }
+                sent_token_probs_index = probs_stop_pos;
+
+                tokenResult[@"completion_probabilities"] = [self tokenProbsToDict:probs_output];
             }
-            sent_token_probs_index = probs_stop_pos;
 
-            tokenResult[@"completion_probabilities"] = [self tokenProbsToDict:probs_output];
+            onToken(tokenResult);
         }
-
-        onToken(tokenResult);
     }
 
     llama_print_timings(llama->ctx);
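
Note (not part of the patch): a minimal standalone sketch of the streaming rule the new condition implements, under stated assumptions. The helper find_partial_stop and the hard-coded stop list below are illustrative only, not llama.rn or llama.cpp API; the real code uses llama->findStoppingStrings with STOP_FULL/STOP_PARTIAL and additionally tracks is_stop_full so text already truncated at a full stop string is not re-sent. The idea: while a stop string may still be forming, nothing is streamed; once generation ends on a partial match, the remaining text is flushed instead of being silently dropped, which is the bug this patch fixes.

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    // Position of the earliest suffix of `text` that could still grow into
    // one of `stops` (a partial stop match), or std::string::npos if none.
    static size_t find_partial_stop(const std::string &text, const std::vector<std::string> &stops) {
        size_t best = std::string::npos;
        for (const auto &stop : stops) {
            for (size_t len = std::min(stop.size() - 1, text.size()); len > 0; len--) {
                if (text.compare(text.size() - len, len, stop, 0, len) == 0) {
                    best = std::min(best, text.size() - len);
                    break;
                }
            }
        }
        return best;
    }

    int main() {
        const std::vector<std::string> stops = {"</s>"};
        const std::string generated = "Hello world</"; // generation ended on a partial stop match
        size_t sent_count = 0;                         // bytes already streamed to the caller
        bool has_next_token = false;                   // no more tokens will arrive

        size_t stop_pos = find_partial_stop(generated.substr(sent_count), stops);
        // Old behaviour: send nothing while a stop string might be forming,
        // which silently dropped the trailing "</" when generation ended here.
        // New behaviour: also flush the remainder once there are no more tokens.
        if (stop_pos == std::string::npos || (!has_next_token && stop_pos > 0)) {
            const std::string to_send = generated.substr(sent_count);
            sent_count += to_send.size();
            std::cout << to_send << "\n";              // prints "Hello world</"
        }
        return 0;
    }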