add createChatCompletion (#140)
* add formatChat and createChatCompletion

* bump llama.cpp upstream source code

* add test for createChatCompletion

* v2.1.0
ngxson authored Dec 6, 2024
1 parent 48e36d1 commit ee31d9f
Showing 12 changed files with 158 additions and 15 deletions.
36 changes: 32 additions & 4 deletions actions.hpp
@@ -339,7 +339,7 @@ json action_sampling_init(app_t &app, json &body)
{
llama_token token = item["token"];
float bias = item["bias"];
sparams.logit_bias.push_back({ token, bias });
sparams.logit_bias.push_back({token, bias});
}
}
// maybe free before creating a new one
@@ -363,7 +363,7 @@ json action_sampling_init(app_t &app, json &body)
json action_get_vocab(app_t &app, json &body)
{
int32_t max_tokens = llama_n_vocab(app.model);
std::vector<std::vector<unsigned int> > vocab(max_tokens);
std::vector<std::vector<unsigned int>> vocab(max_tokens);
for (int32_t id = 0; id < max_tokens; id++)
{
std::string token_as_str = common_token_to_piece(app.ctx, id);
@@ -429,8 +429,8 @@ json action_decode(app_t &app, json &body)
{
std::vector<llama_token> tokens_list = body["tokens"];
bool skip_logits = body.contains("skip_logits")
? body.at("skip_logits").get<bool>()
: false;
? body.at("skip_logits").get<bool>()
: false;
size_t i = 0;
common_batch_clear(app.batch);
for (auto id : tokens_list)
@@ -587,6 +587,34 @@ json action_embeddings(app_t &app, json &body)
};
}

// apply chat template
json action_chat_format(app_t &app, json &body)
{
std::string tmpl = body.contains("tmpl") ? body["tmpl"] : "";
bool add_ass = body.contains("add_ass") ? body.at("add_ass").get<bool>() : false;
if (!body.contains("messages"))
{
return json{{"error", "messages is required"}};
}
std::vector<common_chat_msg> chat;
for (auto &item : body["messages"])
{
chat.push_back({item["role"], item["content"]});
}
try
{
std::string formatted_chat = common_chat_apply_template(app.model, tmpl, chat, add_ass);
return json{
{"success", true},
{"formatted_chat", formatted_chat},
};
}
catch (const std::exception &e)
{
return json{{"error", e.what()}};
}
}

// remove tokens in kv, for context-shifting
json action_kv_remove(app_t &app, json &body)
{
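For orientation, here is a minimal TypeScript sketch of the request and response shapes the new `chat_format` action handles, inferred from `action_chat_format` above; the type names are illustrative only and do not appear in the codebase.

```ts
// Request body for the chat_format action, inferred from action_chat_format:
// "messages" is required; "tmpl" and "add_ass" are optional.
interface ChatFormatRequest {
  messages: { role: string; content: string }[];
  tmpl?: string;     // empty or omitted -> use the model's built-in chat template
  add_ass?: boolean; // true -> append the assistant prefix to start a new turn
}

// Response: success with the formatted prompt, or an error message
// (e.g. "messages is required", or whatever the template code throws).
type ChatFormatResponse =
  | { success: true; formatted_chat: string }
  | { error: string };

const exampleBody: ChatFormatRequest = {
  messages: [
    { role: 'system', content: 'You are helpful.' },
    { role: 'user', content: 'How are you?' },
  ],
  add_ass: true,
};
```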
2 changes: 1 addition & 1 deletion llama.cpp
Submodule llama.cpp updated 97 files
+22 −9 .devops/full.Dockerfile
+11 −5 .devops/llama-cli.Dockerfile
+10 −6 .devops/llama-server.Dockerfile
+1 −7 .github/pull_request_template.md
+29 −134 .github/workflows/build.yml
+16 −10 .github/workflows/server.yml
+4 −0 .gitignore
+0 −4 CMakeLists.txt
+3 −0 CODEOWNERS
+4 −2 CONTRIBUTING.md
+14 −9 Makefile
+1 −1 Package.swift
+4 −1 ci/run.sh
+18 −4 common/arg.cpp
+2 −1 common/common.h
+36 −22 convert_hf_to_gguf.py
+2 −1 convert_hf_to_gguf_update.py
+0 −7 docs/backend/BLIS.md
+97 −151 docs/build.md
+0 −61 examples/base-translate.sh
+1 −4 examples/convert-llama2c-to-ggml/README.md
+1 −1 examples/deprecation-warning/deprecation-warning.cpp
+0 −2 examples/imatrix/README.md
+1 −1 examples/infill/README.md
+9 −0 examples/llava/clip.cpp
+3 −2 examples/main/README.md
+9 −5 examples/server/CMakeLists.txt
+201 −172 examples/server/README.md
+0 −25 examples/server/deps.sh
+0 −13 examples/server/public/deps_daisyui.min.css
+0 −8,442 examples/server/public/deps_markdown-it.js
+0 −82 examples/server/public/deps_tailwindcss.js
+0 −18,160 examples/server/public/deps_vue.esm-browser.js
+245 −624 examples/server/public/index.html
+3 −0 examples/server/public_simplechat/simplechat.js
+968 −455 examples/server/server.cpp
+6 −0 examples/server/tests/README.md
+4 −0 examples/server/tests/tests.sh
+31 −19 examples/server/tests/unit/test_chat_completion.py
+39 −0 examples/server/tests/unit/test_completion.py
+31 −0 examples/server/tests/unit/test_speculative.py
+2 −238 examples/server/utils.hpp
+268 −0 examples/server/webui/index.html
+2,783 −0 examples/server/webui/package-lock.json
+23 −0 examples/server/webui/package.json
+6 −0 examples/server/webui/postcss.config.js
+0 −0 examples/server/webui/src/completion.js
+456 −0 examples/server/webui/src/main.js
+26 −0 examples/server/webui/src/styles.css
+16 −0 examples/server/webui/tailwind.config.js
+36 −0 examples/server/webui/vite.config.js
+24 −24 ggml/CMakeLists.txt
+8 −0 ggml/include/ggml.h
+35 −0 ggml/src/CMakeLists.txt
+38 −20 ggml/src/ggml-backend-impl.h
+193 −87 ggml/src/ggml-backend-reg.cpp
+299 −251 ggml/src/ggml-cpu/CMakeLists.txt
+0 −1 ggml/src/ggml-cpu/amx/common.h
+36 −38 ggml/src/ggml-cpu/amx/mmq.cpp
+323 −0 ggml/src/ggml-cpu/cpu-feats-x86.cpp
+1 −1 ggml/src/ggml-cpu/ggml-cpu-aarch64.c
+120 −3 ggml/src/ggml-cpu/ggml-cpu.c
+9 −1 ggml/src/ggml-cpu/ggml-cpu.cpp
+0 −1 ggml/src/ggml-cuda/fattn-vec-f16.cuh
+0 −1 ggml/src/ggml-cuda/fattn-vec-f32.cuh
+2 −2 ggml/src/ggml-impl.h
+39 −0 ggml/src/ggml-metal/ggml-metal-impl.h
+439 −22 ggml/src/ggml-metal/ggml-metal.m
+629 −7 ggml/src/ggml-metal/ggml-metal.metal
+2 −1 ggml/src/ggml-sycl/CMakeLists.txt
+31 −12 ggml/src/ggml-sycl/dpct/helper.hpp
+60 −43 ggml/src/ggml-sycl/ggml-sycl.cpp
+8 −8 ggml/src/ggml-sycl/outprod.cpp
+753 −92 ggml/src/ggml-vulkan/ggml-vulkan.cpp
+2 −0 ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
+305 −0 ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
+289 −0 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+21 −6 ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp
+25 −6 ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp
+328 −0 ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
+74 −21 ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+35 −2 ggml/src/ggml.c
+6 −3 gguf-py/gguf/constants.py
+1 −1 grammars/README.md
+6 −0 grammars/english.gbnf
+6 −1 include/llama.h
+12 −6 scripts/compare-commits.sh
+0 −212 scripts/pod-llama.sh
+0 −418 scripts/server-llm.sh
+0 −3 scripts/sync-ggml-am.sh
+1 −1 scripts/sync-ggml.last
+0 −1 scripts/sync-ggml.sh
+1 −0 src/llama-vocab.cpp
+261 −188 src/llama.cpp
+53 −5 tests/test-backend-ops.cpp
+44 −4 tests/test-chat-template.cpp
+5 −0 tests/test-lora-conversion-inference.sh
4 changes: 2 additions & 2 deletions package.json
@@ -1,7 +1,7 @@
{
"name": "@wllama/wllama",
"version": "2.0.1",
"description": "Low-level WASM binding for llama.cpp",
"version": "2.1.0",
"description": "WebAssembly binding for llama.cpp - Enabling on-browser LLM inference",
"main": "index.js",
"type": "module",
"directories": {
2 changes: 1 addition & 1 deletion src/multi-thread/wllama.js

Large diffs are not rendered by default.

Binary file modified src/multi-thread/wllama.wasm
2 changes: 1 addition & 1 deletion src/single-thread/wllama.js

Large diffs are not rendered by default.

Binary file modified src/single-thread/wllama.wasm
4 changes: 2 additions & 2 deletions src/wasm-from-cdn.ts
@@ -2,8 +2,8 @@
// Do not edit this file directly

const WasmFromCDN = {
'single-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/wllama@2.0.1/src/single-thread/wllama.wasm',
'multi-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/wllama@2.0.1/src/multi-thread/wllama.wasm',
'single-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/wllama@2.1.0/src/single-thread/wllama.wasm',
'multi-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/wllama@2.1.0/src/multi-thread/wllama.wasm',
};

export default WasmFromCDN;
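For context, a short sketch of how this generated map is typically consumed: pass it to the Wllama constructor in place of self-hosted wasm paths. The import specifiers below are assumptions, not taken from this diff.

```ts
// Sketch: point the Wllama constructor at the CDN-hosted wasm binaries.
// The import paths are assumptions; adjust them to your package/bundler setup.
import { Wllama } from '@wllama/wllama';
import WasmFromCDN from '@wllama/wllama/esm/wasm-from-cdn.js';

// WasmFromCDN has the same keys as AssetsPathConfig, so it can be passed directly.
const wllama = new Wllama(WasmFromCDN);
```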
68 changes: 67 additions & 1 deletion src/wllama.test.ts
@@ -1,5 +1,5 @@
import { test, expect } from 'vitest';
import { Wllama } from './wllama';
import { Wllama, WllamaChatMessage } from './wllama';

const CONFIG_PATHS = {
'single-thread/wllama.wasm': '/src/single-thread/wllama.wasm',
@@ -229,6 +229,72 @@ test.sequential('allowOffline', async () => {
}
});

test.sequential('formatChat', async () => {
const wllama = new Wllama(CONFIG_PATHS, {
allowOffline: true,
});

await wllama.loadModelFromUrl(TINY_MODEL);
expect(wllama.isModelLoaded()).toBe(true);
const messages: WllamaChatMessage[] = [
{ role: 'system', content: 'You are helpful.' },
{ role: 'user', content: 'Hi!' },
{ role: 'assistant', content: 'Hello!' },
{ role: 'user', content: 'How are you?' },
];

const formatted = await wllama.formatChat(messages, false);
expect(formatted).toBe(
'<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n<|im_start|>user\nHow are you?<|im_end|>\n'
);

const formatted1 = await wllama.formatChat(messages, true);
expect(formatted1).toBe(
'<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant\n'
);

const formatted2 = await wllama.formatChat(messages, true, 'zephyr');
expect(formatted2).toBe(
'<|system|>\nYou are helpful.<|endoftext|>\n<|user|>\nHi!<|endoftext|>\n<|assistant|>\nHello!<|endoftext|>\n<|user|>\nHow are you?<|endoftext|>\n<|assistant|>\n'
);

await wllama.exit();
});

test.sequential('generates chat completion', async () => {
const wllama = new Wllama(CONFIG_PATHS);

await wllama.loadModelFromUrl(TINY_MODEL, {
n_ctx: 1024,
});

const config = {
seed: 42,
temp: 0.0,
top_p: 0.95,
top_k: 40,
};

await wllama.samplingInit(config);

const messages: WllamaChatMessage[] = [
{ role: 'system', content: 'You are helpful.' },
{ role: 'user', content: 'Hi!' },
{ role: 'assistant', content: 'Hello!' },
{ role: 'user', content: 'How are you?' },
];
const completion = await wllama.createChatCompletion(messages, {
nPredict: 10,
sampling: config,
});

expect(completion).toBeDefined();
expect(completion).toMatch(/(Sudden|big|scary)+/);
expect(completion.length).toBeGreaterThan(10);

await wllama.exit();
});

test.sequential('cleans up resources', async () => {
const wllama = new Wllama(CONFIG_PATHS);
await wllama.loadModelFromUrl(TINY_MODEL);
50 changes: 49 additions & 1 deletion src/wllama.ts
@@ -55,6 +55,11 @@ export interface WllamaConfig {
modelManager?: ModelManager;
}

export interface WllamaChatMessage {
role: 'system' | 'user' | 'assistant';
content: string;
}

export interface AssetsPathConfig {
'single-thread/wllama.wasm': string;
'multi-thread/wllama.wasm'?: string;
@@ -574,6 +579,23 @@ export class Wllama {
return result;
}

/**
* Generate a completion from a list of chat messages.
*
* NOTE: this function formats the messages with the model's chat template when one is available; otherwise it falls back to the default ChatML format. It may throw an error if the chat template is not compatible with the model.
*
* @param messages Chat messages
* @param options
* @returns Output completion text (only the completion part)
*/
async createChatCompletion(
messages: WllamaChatMessage[],
options: ChatCompletionOptions
): Promise<string> {
const prompt = await this.formatChat(messages, true);
return await this.createCompletion(prompt, options);
}

/**
* Make completion for a given text.
* @param prompt Input text
@@ -962,7 +984,6 @@
* Load session from file (virtual file system)
* TODO: add ability to download the file
* @param filePath
*
*/
async sessionLoad(filePath: string): Promise<void> {
this.checkModelLoaded();
@@ -978,6 +999,33 @@
this.nCachedTokens = cachedTokens.length;
}

/**
* Apply chat template to a list of messages
*
* @param messages list of messages
* @param addAssistant whether to add assistant prompt at the end
* @param template (optional) custom template, see llama-server --chat-template argument for more details
* @returns formatted chat
*/
async formatChat(
messages: WllamaChatMessage[],
addAssistant: boolean,
template?: string
): Promise<string> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('chat_format', {
messages: messages,
tmpl: template,
add_ass: addAssistant,
});
if (result.error) {
throw new WllamaError(result.error);
} else if (!result.success) {
throw new WllamaError('formatChat unknown error');
}
return result.formatted_chat;
}

/**
* Set options for the underlying llama_context
*/
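Taken together, `formatChat` and `createChatCompletion` are the new public API added by this commit. A minimal usage sketch follows, mirroring the calls exercised in wllama.test.ts; the wasm path, model URL, and sampling values are placeholders.

```ts
import { Wllama, WllamaChatMessage } from '@wllama/wllama';

// Wasm path and model URL are placeholders; see CONFIG_PATHS and TINY_MODEL in wllama.test.ts.
const wllama = new Wllama({
  'single-thread/wllama.wasm': '/src/single-thread/wllama.wasm',
});
await wllama.loadModelFromUrl('https://example.com/tiny-model.gguf', { n_ctx: 1024 });

const messages: WllamaChatMessage[] = [
  { role: 'system', content: 'You are helpful.' },
  { role: 'user', content: 'How are you?' },
];

// Inspect the prompt produced by the model's chat template
// (the second argument appends the assistant prefix, i.e. add_ass in actions.hpp).
const prompt = await wllama.formatChat(messages, true);

// Or format and generate in one call; option names follow the test above.
const reply = await wllama.createChatCompletion(messages, {
  nPredict: 64,
  sampling: { temp: 0.2, top_k: 40, top_p: 0.95 },
});

console.log(prompt, reply);
await wllama.exit();
```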
4 changes: 2 additions & 2 deletions src/workers-code/generated.ts

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions wllama.cpp
@@ -80,6 +80,7 @@ extern "C" const char *wllama_action(const char *name, const char *body)
WLLAMA_ACTION(encode);
WLLAMA_ACTION(get_logits);
WLLAMA_ACTION(embeddings);
WLLAMA_ACTION(chat_format);
WLLAMA_ACTION(kv_remove);
WLLAMA_ACTION(kv_clear);
WLLAMA_ACTION(current_status);
