Skip to content

Commit

Permalink
feat(llama.cpp): Add support to grammar triggers (#4733)
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Feb 2, 2025
1 parent d79f02e commit 1d6afbd
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 1 deletion.
7 changes: 7 additions & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ message Reply {
double timing_token_generation = 5;
}

message GrammarTrigger {
string word = 1;
bool at_start = 2;
}

message ModelOptions {
string Model = 1;
int32 ContextSize = 2;
Expand Down Expand Up @@ -247,6 +252,8 @@ message ModelOptions {

string CacheTypeKey = 63;
string CacheTypeValue = 64;

repeated GrammarTrigger GrammarTriggers = 65;
}

message Result {
Expand Down
20 changes: 20 additions & 0 deletions backend/cpp/llama/grpc-server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,9 @@ struct llama_server_context
bool add_bos_token = true;
bool has_eos_token = true;

bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_trigger_words;

int32_t n_ctx; // total context for all clients / slots

// system prompt
Expand Down Expand Up @@ -706,6 +709,8 @@ struct llama_server_context
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
slot->sparams.grammar_trigger_words = grammar_trigger_words;
slot->sparams.grammar_lazy = grammar_lazy;

if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
// Might be better to reject the request with a 400 ?
Expand Down Expand Up @@ -2374,6 +2379,21 @@ static void params_parse(const backend::ModelOptions* request,
if ( request->ropefreqscale() != 0.0f ) {
params.rope_freq_scale = request->ropefreqscale();
}

if (request->grammartriggers_size() > 0) {
LOG_INFO("configuring grammar triggers", {});
llama.grammar_lazy = true;
for (int i = 0; i < request->grammartriggers_size(); i++) {
common_grammar_trigger trigger;
trigger.word = request->grammartriggers(i).word();
trigger.at_start = request->grammartriggers(i).at_start();
llama.grammar_trigger_words.push_back(trigger);
LOG_INFO("grammar trigger", {
{ "word", trigger.word },
{ "at_start", trigger.at_start }
});
}
}
}


Expand Down
10 changes: 10 additions & 0 deletions core/backend/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,19 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
nGPULayers = *c.NGPULayers
}

triggers := make([]*pb.GrammarTrigger, 0)
for _, t := range c.FunctionsConfig.GrammarConfig.GrammarTriggers {
triggers = append(triggers, &pb.GrammarTrigger{
Word: t.Word,
AtStart: t.AtStart,
})

}

return &pb.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
GrammarTriggers: triggers,
PipelineType: c.Diffusers.PipelineType,
CFGScale: c.CFGScale,
LoraAdapter: c.LoraAdapter,
Expand Down
10 changes: 9 additions & 1 deletion pkg/functions/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ type GrammarConfig struct {
// SchemaType can be configured to use a specific schema type to force the grammar
// available : json, llama3.1
SchemaType string `yaml:"schema_type"`

GrammarTriggers []GrammarTrigger `yaml:"triggers"`
}

type GrammarTrigger struct {
// Trigger is the string that triggers the grammar
Word string `yaml:"word"`
AtStart bool `yaml:"at_start"`
}

// FunctionsConfig is the configuration for the tool/function call.
Expand Down Expand Up @@ -361,6 +369,6 @@ func ParseFunctionCallArgs(functionArguments string, functionConfig FunctionsCon
}

jsonBytes, _ := json.Marshal(args)

return string(jsonBytes)
}

0 comments on commit 1d6afbd

Please sign in to comment.