fix: send rest of content on stop
jhen0409 committed Aug 29, 2023
1 parent 17f24c7 commit 8a4b863
Showing 2 changed files with 59 additions and 43 deletions.
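Both files receive the same fix. The completion loop deliberately holds back generated text that partially matches a stop sequence (rnllama::STOP_PARTIAL), so the stop string itself can be suppressed if the match later completes. Before this commit, the ternary sent an empty string whenever any stop position was found, so a generation that ended while a partial match was still pending dropped that held-back tail. The new condition flushes the remainder in exactly that case. Below is a minimal sketch of the send decision with the per-iteration context state reduced to plain variables; StreamState and should_flush are illustrative names, not llama.rn types.

```cpp
#include <cstddef>
#include <string>

// Illustrative reduction of the per-iteration state in jni.cpp and
// RNLlamaContext.mm; not an actual llama.rn type.
struct StreamState {
    bool has_next_token;   // generation still has tokens to produce
    bool is_stop_full;     // a stop string matched fully (and was erased)
    std::size_t stop_pos;  // match position, or std::string::npos for none
};

// Mirrors the new condition guarding the partial-completion callback:
// send when no stop sequence matches at all, or flush the held-back
// remainder when generation ended on a partial match that never completed.
bool should_flush(const StreamState &s) {
    return s.stop_pos == std::string::npos ||
           (!s.has_next_token && !s.is_stop_full && s.stop_pos > 0);
}
```

The is_stop_full flag distinguishes a genuine stop-string hit, where the matched text is erased from generated_text and must stay unsent, from a dangling partial match that turned out to be ordinary content.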
56 changes: 32 additions & 24 deletions android/src/main/jni.cpp
@@ -317,51 +317,59 @@ Java_com_rnllama_LlamaContext_doCompletion(
 
     while (llama->has_next_token && !llama->is_interrupted) {
         const rnllama::completion_token_output token_with_probs = llama->doCompletion();
-        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama->ctx, token_with_probs.tok);
-        if (llama->multibyte_pending > 0) {
+        if (token_with_probs.tok == -1 || llama->multibyte_pending > 0) {
             continue;
         }
+        const std::string token_text = llama_token_to_str(llama->ctx, token_with_probs.tok);
 
         size_t pos = std::min(sent_count, llama->generated_text.size());
 
         const std::string str_test = llama->generated_text.substr(pos);
+        bool is_stop_full = false;
         size_t stop_pos =
             llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
         if (stop_pos != std::string::npos) {
+            is_stop_full = true;
             llama->generated_text.erase(
                 llama->generated_text.begin() + pos + stop_pos,
                 llama->generated_text.end());
             pos = std::min(sent_count, llama->generated_text.size());
         } else {
+            is_stop_full = false;
             stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
                 rnllama::STOP_PARTIAL);
         }
 
-        const std::string to_send = stop_pos == std::string::npos ?
-            llama->generated_text.substr(pos, std::string::npos) :
-            ""; // just don't send anything if we're not done
-        sent_count += to_send.size();
+        if (
+            stop_pos == std::string::npos ||
+            // Send rest of the text if we are at the end of the generation
+            (!llama->has_next_token && !is_stop_full && stop_pos > 0)
+        ) {
+            const std::string to_send = llama->generated_text.substr(pos, std::string::npos);
 
-        std::vector<rnllama::completion_token_output> probs_output = {};
+            sent_count += to_send.size();
 
-        auto tokenResult = createWriteableMap(env);
-        putString(env, tokenResult, "token", to_send.c_str());
-
-        if (llama->params.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
-            size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
-            size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
-            if (probs_pos < probs_stop_pos) {
-                probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
-            }
-            sent_token_probs_index = probs_stop_pos;
-
-            putArray(env, tokenResult, "completion_probabilities", tokenProbsToMap(env, llama, probs_output));
-        }
+            std::vector<rnllama::completion_token_output> probs_output = {};
+
+            auto tokenResult = createWriteableMap(env);
+            putString(env, tokenResult, "token", to_send.c_str());
+
+            if (llama->params.n_probs > 0) {
+                const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
+                size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
+                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
+                if (probs_pos < probs_stop_pos) {
+                    probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
+                }
+                sent_token_probs_index = probs_stop_pos;
+
+                putArray(env, tokenResult, "completion_probabilities", tokenProbsToMap(env, llama, probs_output));
+            }
 
-        jclass cb_class = env->GetObjectClass(partial_completion_callback);
-        jmethodID onPartialCompletion = env->GetMethodID(cb_class, "onPartialCompletion", "(Lcom/facebook/react/bridge/WritableMap;)V");
-        env->CallVoidMethod(partial_completion_callback, onPartialCompletion, tokenResult);
+            jclass cb_class = env->GetObjectClass(partial_completion_callback);
+            jmethodID onPartialCompletion = env->GetMethodID(cb_class, "onPartialCompletion", "(Lcom/facebook/react/bridge/WritableMap;)V");
+            env->CallVoidMethod(partial_completion_callback, onPartialCompletion, tokenResult);
+        }
     }
 
     llama_print_timings(llama->ctx);
46 changes: 27 additions & 19 deletions ios/RNLlamaContext.mm
Expand Up @@ -204,49 +204,57 @@ - (NSDictionary *)completion:(NSDictionary *)params

while (llama->has_next_token && !llama->is_interrupted) {
const rnllama::completion_token_output token_with_probs = llama->doCompletion();
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama->ctx, token_with_probs.tok);
if (llama->multibyte_pending > 0) {
if (token_with_probs.tok == -1 || llama->multibyte_pending > 0) {
continue;
}
const std::string token_text = llama_token_to_str(llama->ctx, token_with_probs.tok);

size_t pos = std::min(sent_count, llama->generated_text.size());

const std::string str_test = llama->generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos =
llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
if (stop_pos != std::string::npos) {
is_stop_full = true;
llama->generated_text.erase(
llama->generated_text.begin() + pos + stop_pos,
llama->generated_text.end());
pos = std::min(sent_count, llama->generated_text.size());
} else {
is_stop_full = false;
stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
rnllama::STOP_PARTIAL);
}

const std::string to_send = stop_pos == std::string::npos ?
llama->generated_text.substr(pos, std::string::npos) :
""; // just don't send anything if we're not done
sent_count += to_send.size();
if (
stop_pos == std::string::npos ||
// Send rest of the text if we are at the end of the generation
(!llama->has_next_token && !is_stop_full && stop_pos > 0)
) {
const std::string to_send = llama->generated_text.substr(pos, std::string::npos);

std::vector<rnllama::completion_token_output> probs_output = {};
sent_count += to_send.size();

NSMutableDictionary *tokenResult = [[NSMutableDictionary alloc] init];
tokenResult[@"token"] = [NSString stringWithUTF8String:to_send.c_str()];
std::vector<rnllama::completion_token_output> probs_output = {};

if (llama->params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
NSMutableDictionary *tokenResult = [[NSMutableDictionary alloc] init];
tokenResult[@"token"] = [NSString stringWithUTF8String:to_send.c_str()];

if (llama->params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;

tokenResult[@"completion_probabilities"] = [self tokenProbsToDict:probs_output];
}
sent_token_probs_index = probs_stop_pos;

tokenResult[@"completion_probabilities"] = [self tokenProbsToDict:probs_output];
onToken(tokenResult);
}

onToken(tokenResult);
}

llama_print_timings(llama->ctx);
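Both hunks above implement the same buffering rule, so the behavioral change can be demonstrated end to end with a self-contained example. This is a sketch under stated assumptions: the stop sequence "###" and the function find_partial_stop are hypothetical stand-ins for the completion parameters and for rnllama::findStoppingStrings with STOP_PARTIAL.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>

// Simplified stand-in for rnllama::findStoppingStrings(..., STOP_PARTIAL):
// returns the position where a suffix of `text` matches a prefix of `stop`,
// or std::string::npos if there is no such overlap.
std::size_t find_partial_stop(const std::string &text, const std::string &stop) {
    for (std::size_t len = std::min(text.size(), stop.size()); len > 0; --len) {
        if (text.compare(text.size() - len, len, stop, 0, len) == 0) {
            return text.size() - len;
        }
    }
    return std::string::npos;
}

int main() {
    const std::string generated = "The answer is 42 #"; // generation ended here
    const std::string stop = "###";

    const std::size_t stop_pos = find_partial_stop(generated, stop); // 17
    const bool has_next_token = false; // the model produced its last token
    const bool is_stop_full = false;   // "###" never completed

    // Old behavior: stop_pos != npos meant to_send was "", so the trailing
    // "The answer is 42 #" was silently lost when the loop exited.
    // New behavior: the pending remainder is flushed to the callback.
    if (stop_pos == std::string::npos ||
        (!has_next_token && !is_stop_full && stop_pos > 0)) {
        std::cout << "sent: " << generated << "\n";
    }
    return 0;
}
```

Under the old ternary the remainder would have been replaced with an empty string; with the committed condition the program prints `sent: The answer is 42 #`, which is the "rest of content on stop" the commit title refers to.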
