add benchmark function, used internally #151

Merged · 3 commits · Jan 30, 2025
129 changes: 126 additions & 3 deletions actions.hpp
@@ -684,14 +684,137 @@ json action_current_status(app_t &app, json &body)
};
}

//
// benchmark & perplexity
//

json action_test_benchmark(app_t &app, json &body)
{
std::string type = body.at("type"); // "pp" (prompt processing) or "tg" (token generation)
int n_samples = body.at("n_samples"); // n_batch for pp, n_predict for tg

llama_kv_cache_clear(app.ctx);
int n_vocab = llama_vocab_n_tokens(app.vocab);
int64_t t_start = ggml_time_ms();

if (type == "pp")
{
llama_batch batch = llama_batch_init(n_samples, 0, 1);
for (int i = 0; i < n_samples; i++)
{
common_batch_add(batch, i % n_vocab, i, {0}, i == n_samples - 1);
}
int ret = llama_decode(app.ctx, batch);
llama_batch_free(batch);
if (ret != 0)
{
return json{{"error", "llama_decode failed with status = " + std::to_string(ret)}};
}
}
else if (type == "tg")
{
llama_batch batch = llama_batch_init(1, 0, 1);
for (int i = 0; i < n_samples; i++)
{
common_batch_clear(batch);
common_batch_add(batch, i % n_vocab, i, {0}, true);
int ret = llama_decode(app.ctx, batch);
if (ret != 0)
{
return json{{"error", "llama_decode failed with status = " + std::to_string(ret)}};
}
}
llama_batch_free(batch);
}
else
{
return json{{"error", "unknown type: " + type}};
}

int64_t t_end = ggml_time_ms();
return json{
{"success", true},
{"t_ms", t_end - t_start},
};
}
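
The handler times a single batched decode of `n_samples` tokens for `pp`, or `n_samples` sequential single-token decodes for `tg`, and returns only the elapsed wall time; turning that into a throughput figure is left to the caller. A minimal TypeScript sketch of that conversion, assuming a loaded `Wllama` instance and the internal `_testBenchmark` wrapper added in `src/wllama.ts` below (mirroring what `examples/main/src/utils/benchmark.ts` does):

```ts
import { Wllama } from '@wllama/wllama';

// Sketch only: n_samples tokens were decoded in t_ms milliseconds,
// so throughput is n_samples / (t_ms / 1000) tokens per second.
async function tokensPerSecond(
  wllama: Wllama,
  type: 'pp' | 'tg',
  nSamples: number
): Promise<number> {
  const { t_ms } = await wllama._testBenchmark(type, nSamples);
  return nSamples / (t_ms / 1000);
}
```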

json action_test_perplexity(app_t &app, json &body)
{
llama_tokens input = body["input"];
const size_t n = input.size();

int64_t t_start = ggml_time_ms();

if (n < 2)
{
return json{{"error", "Input must contain at least two tokens"}};
}

// Clear existing context to start fresh
llama_kv_cache_clear(app.ctx);
app.tokens.clear();

const int32_t n_vocab = llama_vocab_n_tokens(app.vocab);
double nll = 0.0;

static auto log_softmax = [](int n_vocab, const float *logits, int tok) -> double
{
float max_logit = logits[0];
for (int i = 1; i < n_vocab; ++i)
{
max_logit = std::max(max_logit, logits[i]);
}
double sum_exp = 0.0;
for (int i = 0; i < n_vocab; ++i)
{
sum_exp += expf(logits[i] - max_logit);
}
return logits[tok] - max_logit - log(sum_exp);
};

for (size_t i = 0; i < n - 1; ++i)
{
// Prepare batch with current token (input[i])
common_batch_clear(app.batch);
common_batch_add(app.batch, input[i], i, {0}, true); // Enable logits for this token

if (llama_decode(app.ctx, app.batch) != 0)
{
return json{{"error", "Decoding failed at position " + std::to_string(i)}};
}

float *logits = llama_get_logits_ith(app.ctx, 0);

// Get true next token (input[i+1])
const int32_t true_token = input[i + 1];

nll += -log_softmax(n_vocab, logits, true_token);
}

// Calculate final metrics
const double cross_entropy = nll / (n - 1);
const double ppl = std::exp(cross_entropy);

int64_t t_end = ggml_time_ms();

return json{
{"success", true},
{"ppl", ppl},
{"nll", nll},
{"cross_entropy", cross_entropy},
{"n_tokens", n - 1},
{"t_ms", t_end - t_start},
};
}
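
For each of the `n - 1` predicted positions, the loop decodes the current token and accumulates the negative log-probability of the true next token, computed with the max-shifted log-sum-exp form from the `log_softmax` lambda for numerical stability; the reported perplexity is the exponential of the mean. In formula form (with `z` the logits at position `i` and `V = n_vocab`):

```latex
\log p(t \mid x_{\le i}) = z_t - z_{\max} - \log \sum_{v=1}^{V} e^{z_v - z_{\max}},
\qquad
\mathrm{ppl} = \exp\!\left(\frac{1}{n-1}\sum_{i=1}^{n-1} -\log p(x_{i+1} \mid x_{\le i})\right)
```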

//////////////////////////////////////////

// Because we can't support Jinja for now, we temporarily use an old version of common_chat_apply_template
// TODO: support jinja
std::string common_chat_apply_template_old(const struct llama_model *model,
const std::string &tmpl,
const std::vector<common_chat_msg> &msgs,
bool add_ass)
{
int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml
1 change: 1 addition & 0 deletions examples/main/src/App.tsx
@@ -7,6 +7,7 @@ import Sidebar from './components/Sidebar';
import { MessagesProvider } from './utils/messages.context';
import { Screen } from './utils/types';
import { useWllama, WllamaProvider } from './utils/wllama.context';
import './utils/benchmark';

function App() {
return (
116 changes: 116 additions & 0 deletions examples/main/src/utils/benchmark.ts
@@ -0,0 +1,116 @@
import { Wllama } from '@wllama/wllama';
import { WLLAMA_CONFIG_PATHS } from '../config';
import { delay } from './utils';

// TODO: this is console-only for now, should we implement a GUI in the future?

const WIKITEXT_URL =
'https://raw.githubusercontent.com/wangfin/QAsystem/refs/heads/master/QAManagement/language_model/data/wikitext-2/valid.txt';

const BENCH_MODELS = [
'https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q8_0.gguf',
'https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf',
'https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_K_M.gguf',
'https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q5_K_L.gguf',
];

const BENCH_N_REPEATED = 4;

const BENCH_CONFIGS: { type: 'pp' | 'tg'; n_samples: number }[] = [
{ type: 'pp', n_samples: 32 },
{ type: 'pp', n_samples: 64 },
{ type: 'pp', n_samples: 128 },
{ type: 'pp', n_samples: 256 },
{ type: 'tg', n_samples: 32 },
{ type: 'tg', n_samples: 64 },
{ type: 'tg', n_samples: 128 },
{ type: 'tg', n_samples: 256 },
];

async function loadModel(modelUrl: string) {
const modelFile = modelUrl.split('/').pop();
const wllama = new Wllama(WLLAMA_CONFIG_PATHS);
await wllama.loadModelFromUrl(modelUrl, {
n_batch: 512,
n_ctx: 4096,
progressCallback: ({ total, loaded }) => {
console.log(`Model ${modelFile}: ${Math.round((100 * loaded) / total)}%`);
},
});
return { wllama, modelFile };
}

async function benchmark() {
const output: any[][] = [
['model', 'threads', 'test', 't/s'],
['---', '---', '---', '---'],
];
for (const modelUrl of BENCH_MODELS) {
const [{ wllama, modelFile }] = await Promise.all([
loadModel(modelUrl),
delay(10000), // force delay for CPU to cool down
]);
console.clear();
const nThreads = wllama.getNumThreads();
for (const config of BENCH_CONFIGS) {
const { type, n_samples } = config;
const results: number[] = [];
for (let i = 0; i < BENCH_N_REPEATED; i++) {
console.log('Running', modelFile, config);
const { t_ms } = await wllama._testBenchmark(type, n_samples);
const t_per_tok = n_samples / (t_ms / 1000);
results.push(t_per_tok);
console.log('Run', i, 'perf:', t_per_tok, 't/s');
}
const t_avg = results.reduce((a, b) => a + b, 0) / results.length;
const t_plus_minus = Math.abs(
Math.max(...results) - Math.min(...results)
);
output.push([
modelFile,
nThreads,
`${type} ${n_samples}`,
`${t_avg.toFixed(2)} ± ${t_plus_minus.toFixed(2)}`,
]);
}
wllama.exit();
}

console.table(output);
const markdown = output
.map((row) => '| ' + row.join(' | ') + ' |')
.join('\n');
console.log(markdown);
}

async function perplexity() {
const output: any[][] = [
['model', 'PPL', 'n_tokens'],
['---', '---', '---'],
];
const LIMIT_TOKENS = 2048;
const wikitext = await fetch(WIKITEXT_URL).then((res) => res.text());
console.log('Loaded wikitext:', wikitext.substring(0, 100), '...');
for (const modelUrl of BENCH_MODELS) {
const { wllama, modelFile } = await loadModel(modelUrl);
console.clear();
let tokens = await wllama.tokenize(
wikitext.substring(0, LIMIT_TOKENS * 16)
);
tokens = tokens.slice(0, LIMIT_TOKENS);
console.log('Running', modelFile, 'n_tokens', tokens.length);
const { ppl } = await wllama._testPerplexity(tokens);
console.log('PPL:', ppl);
output.push([modelFile, ppl, tokens.length]);
wllama.exit();
}

console.table(output);
const markdown = output
.map((row) => '| ' + row.join(' | ') + ' |')
.join('\n');
console.log(markdown);
}

(window as any).__benchmark = benchmark;
(window as any).__perplexity = perplexity;
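
Because `examples/main/src/App.tsx` imports `./utils/benchmark` (see above), these globals are registered when the example app loads, so both suites can be started from the browser devtools console. A usage sketch:

```ts
// In the devtools console of the running example app:
await (window as any).__benchmark();  // throughput table: model × (pp/tg, n_samples) in t/s
await (window as any).__perplexity(); // wikitext-2 perplexity table per model
```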
2 changes: 1 addition & 1 deletion llama.cpp
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "@wllama/wllama",
"version": "2.1.3",
"version": "2.1.4",
"description": "WebAssembly binding for llama.cpp - Enabling on-browser LLM inference",
"main": "index.js",
"type": "module",
2 changes: 1 addition & 1 deletion scripts/docker-compose.yml
@@ -19,7 +19,7 @@ services:
mkdir -p wasm/single-thread
cd wasm/single-thread

export SHARED_EMCC_CFLAGS="--no-entry -O3 -msimd128 -fno-rtti -DNDEBUG -flto=full -frtti -fwasm-exceptions -sEXPORT_ALL=1 -sEXPORT_ES6=0 -sMODULARIZE=0 -sINITIAL_MEMORY=128MB -sMAXIMUM_MEMORY=4096MB -sALLOW_MEMORY_GROWTH=1 -sFORCE_FILESYSTEM=1 -sEXPORTED_FUNCTIONS=_main,_wllama_start,_wllama_action,_wllama_exit,_wllama_debug -sEXPORTED_RUNTIME_METHODS=ccall,cwrap -sNO_EXIT_RUNTIME=1"
export SHARED_EMCC_CFLAGS="--no-entry -O3 -msimd128 -DNDEBUG -flto=full -frtti -fwasm-exceptions -sEXPORT_ALL=1 -sEXPORT_ES6=0 -sMODULARIZE=0 -sINITIAL_MEMORY=128MB -sMAXIMUM_MEMORY=4096MB -sALLOW_MEMORY_GROWTH=1 -sFORCE_FILESYSTEM=1 -sEXPORTED_FUNCTIONS=_main,_wllama_start,_wllama_action,_wllama_exit,_wllama_debug -sEXPORTED_RUNTIME_METHODS=ccall,cwrap -sNO_EXIT_RUNTIME=1"

# emcc --clear-cache

2 changes: 1 addition & 1 deletion src/multi-thread/wllama.js

Large diffs are not rendered by default.

Binary file modified src/multi-thread/wllama.wasm
2 changes: 1 addition & 1 deletion src/single-thread/wllama.js

Large diffs are not rendered by default.

Binary file modified src/single-thread/wllama.wasm
4 changes: 2 additions & 2 deletions src/wasm-from-cdn.ts
@@ -2,8 +2,8 @@
// Do not edit this file directly

const WasmFromCDN = {
'single-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/[email protected].3/src/single-thread/wllama.wasm',
'multi-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/[email protected].3/src/multi-thread/wllama.wasm',
'single-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/[email protected].4/src/single-thread/wllama.wasm',
'multi-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/[email protected].4/src/multi-thread/wllama.wasm',
};

export default WasmFromCDN;
36 changes: 36 additions & 0 deletions src/wllama.ts
@@ -210,6 +210,7 @@ export class Wllama {
private config: WllamaConfig;
private pathConfig: AssetsPathConfig;
private useMultiThread: boolean = false;
private nbThreads: number = 1;
private useEmbeddings: boolean = false;
// available when loaded
private loadedContextInfo: LoadedContextInfo = null as any;
@@ -344,6 +345,18 @@ export class Wllama {
return this.useMultiThread;
}

/**
* Get number of threads used in the current context.
*
* NOTE: This can only be used after `loadModel` is called.
*
* @returns number of threads
*/
getNumThreads(): number {
this.checkModelLoaded();
return this.useMultiThread ? this.nbThreads : 1;
}

/**
* Check if the current model uses encoder-decoder architecture
*
@@ -478,6 +491,7 @@ export class Wllama {
}
const hwConccurency = Math.floor((navigator.hardwareConcurrency || 1) / 2);
const nbThreads = config.n_threads ?? hwConccurency;
this.nbThreads = nbThreads;
this.useMultiThread =
supportMultiThread && hasPathMultiThread && nbThreads > 1;
const mPathConfig = this.useMultiThread
@@ -1065,6 +1079,28 @@ export class Wllama {
return await this.proxy.wllamaDebug();
}

/**
* benchmark function, only used internally
*/
async _testBenchmark(
type: 'tg' | 'pp',
nSamples: number
): Promise<{ t_ms: number }> {
this.checkModelLoaded();
return await this.proxy.wllamaAction('test_benchmark', {
type,
n_samples: nSamples,
});
}

/**
* perplexity function, only used internally
*/
async _testPerplexity(input: number[]): Promise<{ ppl: number }> {
this.checkModelLoaded();
return await this.proxy.wllamaAction('test_perplexity', { input });
}

///// Prompt cache utils /////
private async getCachedTokens(): Promise<number[]> {
this.checkModelLoaded();
4 changes: 2 additions & 2 deletions src/workers-code/generated.ts

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions wllama.cpp
@@ -86,6 +86,8 @@ extern "C" const char *wllama_action(const char *name, const char *body)
WLLAMA_ACTION(current_status);
WLLAMA_ACTION(session_save);
WLLAMA_ACTION(session_load);
WLLAMA_ACTION(test_benchmark);
WLLAMA_ACTION(test_perplexity);
result = std::string(res.dump());
return result.c_str();
}