diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp index f44cbbd..c7efc9d 100644 --- a/cpp/rn-llama.hpp +++ b/cpp/rn-llama.hpp @@ -5,6 +5,7 @@ #include <iostream> #include "common.h" #include "llama.h" +#include "grammar-parser.h" namespace rnllama { @@ -141,6 +142,9 @@ struct llama_rn_context llama_context *ctx = nullptr; gpt_params params; + grammar_parser::parse_state parsed_grammar; + llama_grammar *grammar = nullptr; + bool truncated = false; bool stopped_eos = false; bool stopped_word = false; @@ -165,6 +169,7 @@ struct llama_rn_context void rewind() { params.antiprompt.clear(); + params.grammar.clear(); num_prompt_tokens = 0; num_tokens_predicted = 0; generated_text = ""; @@ -176,9 +181,13 @@ struct llama_rn_context stopped_limit = false; stopping_word = ""; multibyte_pending = 0; - n_remain = 0; n_past = 0; + + if (grammar != nullptr) { + llama_grammar_free(grammar); + grammar = nullptr; + } } bool loadModel(gpt_params &params_) { @@ -196,6 +205,31 @@ return true; } + bool loadGrammar() + { + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + LOG_ERROR("grammar parse error, grammar: %s", params.grammar.c_str()); + return false; + } + grammar_parser::print_grammar(stderr, parsed_grammar); + + { + auto it = params.logit_bias.find(llama_token_eos()); + if (it != params.logit_bias.end() && it->second == -INFINITY) { + LOG_WARNING("EOS token is disabled, which will cause most grammars to fail"); + } + } + + std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + return true; + } void loadPrompt() { params.prompt.insert(0, 1, ' '); // always add a first space @@ -355,6 +389,10 @@ logits[llama_token_nl()] = nl_logit; } + if (grammar != nullptr) { + llama_sample_grammar(ctx, 
&candidates_p, grammar); + } + if (temp <= 0) { // Greedy sampling @@ -392,6 +430,10 @@ } } + if (grammar != nullptr) { + llama_grammar_accept_token(ctx, grammar, result.tok); + } + for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i) { result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); diff --git a/example/App.tsx b/example/App.tsx index 46a870b..5ed6d6e 100644 --- a/example/App.tsx +++ b/example/App.tsx @@ -91,7 +91,7 @@ export default function App() { initLlama({ model: file.uri, use_mlock: true, - n_gpu_layers: 1, // > 0: enable metal + n_gpu_layers: 0, // > 0: enable metal }) .then((ctx) => { setContext(ctx) diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index d01f168..eca69cb 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -132,6 +132,10 @@ - (NSDictionary *)completion:(NSDictionary *)params self->llama->params.prompt = [prompt UTF8String]; + if (params[@"grammar"]) { + self->llama->params.grammar = [params[@"grammar"] UTF8String]; + } + if (params[@"temperature"]) self->llama->params.temp = [params[@"temperature"] doubleValue]; if (params[@"n_threads"]) { @@ -188,6 +192,10 @@ - (NSDictionary *)completion:(NSDictionary *)params } } } + + if (!self->llama->loadGrammar()) { + @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to load grammar" userInfo:nil]; + } self->llama->loadPrompt(); self->llama->beginCompletion(); @@ -278,6 +286,9 @@ - (void)stopCompletion { } - (void)invalidate { + if (self->llama->grammar != nullptr) { + llama_grammar_free(self->llama->grammar); + } delete self->llama; // llama_backend_free(); diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index 033004d..e09e858 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -17,6 +17,8 @@ cp ./llama.cpp/k_quants.h ./cpp/k_quants.h cp ./llama.cpp/k_quants.c ./cpp/k_quants.c cp ./llama.cpp/examples/common.h ./cpp/common.h cp ./llama.cpp/examples/common.cpp 
./cpp/common.cpp +cp ./llama.cpp/examples/grammar-parser.h ./cpp/grammar-parser.h +cp ./llama.cpp/examples/grammar-parser.cpp ./cpp/grammar-parser.cpp # List of files to process files=( diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index 57d2b41..52462b1 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -27,6 +27,7 @@ export type NativeContextParams = { export type NativeCompletionParams = { prompt: string + grammar?: string stop?: Array<string> // -> antiprompt n_predict?: number