feat: add bench method (#35)
* feat: add bench method

* feat(example): improve output & make bench result in bubble copyable

* chore: avoid unnecessary type convert

* feat(example): add heat up time output

* feat: move json parse to api

* docs(api): update
jhen0409 authored Dec 19, 2023
1 parent a57171e commit 9d0c7df
Showing 19 changed files with 417 additions and 56 deletions.
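The new `bench(pp, tg, pl, nr)` method runs a llama.cpp-style throughput benchmark: `pp` prompt tokens, `tg` generated tokens per repetition, `pl` parallel sequences, and `nr` repetitions. A minimal usage sketch, assuming the JS API exposes it on the context returned by `initLlama` and resolves to the `BenchResult` type documented below (the model path here is a placeholder):

```typescript
import { initLlama } from 'llama.rn'

async function runBench() {
  // Placeholder model path; any local GGUF model works.
  const context = await initLlama({ model: 'file:///path/to/model.gguf' })

  // pp = 512 prompt tokens, tg = 128 generated tokens,
  // pl = 1 parallel sequence, nr = 3 repetitions.
  const { modelDesc, ppAvg, ppStd, tgAvg, tgStd } =
    await context.bench(512, 128, 1, 3)

  console.log(
    `${modelDesc}: pp ${ppAvg.toFixed(2)} ± ${ppStd.toFixed(2)} t/s, ` +
      `tg ${tgAvg.toFixed(2)} ± ${tgStd.toFixed(2)} t/s`,
  )

  await context.release()
}
```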
5 changes: 5 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -216,6 +216,10 @@ public WritableMap embedding(String text) {
return result;
}

public String bench(int pp, int tg, int pl, int nr) {
return bench(this.context, pp, tg, pl, nr);
}

public void release() {
freeContext(context);
}
@@ -329,5 +333,6 @@ protected static native WritableMap doCompletion(
protected static native String detokenize(long contextPtr, int[] tokens);
protected static native boolean isEmbeddingEnabled(long contextPtr);
protected static native WritableArray embedding(long contextPtr, String text);
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
protected static native void freeContext(long contextPtr);
}
32 changes: 32 additions & 0 deletions android/src/main/java/com/rnllama/RNLlama.java
@@ -316,6 +316,38 @@ protected void onPostExecute(WritableMap result) {
tasks.put(task, "embedding-" + contextId);
}

public void bench(double id, final double pp, final double tg, final double pl, final double nr, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, String>() {
private Exception exception;

@Override
protected String doInBackground(Void... voids) {
try {
LlamaContext context = contexts.get(contextId);
if (context == null) {
throw new Exception("Context not found");
}
return context.bench((int) pp, (int) tg, (int) pl, (int) nr);
} catch (Exception e) {
exception = e;
}
return null;
}

@Override
protected void onPostExecute(String result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.execute();
tasks.put(task, "bench-" + contextId);
}

public void releaseContext(double id, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Void>() {
16 changes: 16 additions & 0 deletions android/src/main/jni.cpp
@@ -561,6 +561,22 @@ Java_com_rnllama_LlamaContext_embedding(
return result;
}

JNIEXPORT jstring JNICALL
Java_com_rnllama_LlamaContext_bench(
JNIEnv *env,
jobject thiz,
jlong context_ptr,
jint pp,
jint tg,
jint pl,
jint nr
) {
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];
std::string result = llama->bench(pp, tg, pl, nr);
return env->NewStringUTF(result.c_str());
}

JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_freeContext(
JNIEnv *env, jobject thiz, jlong context_ptr) {
5 changes: 5 additions & 0 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
@@ -77,6 +77,11 @@ public void embedding(double id, final String text, final Promise promise) {
rnllama.embedding(id, text, promise);
}

@ReactMethod
public void bench(double id, final double pp, final double tg, final double pl, final double nr, final Promise promise) {
rnllama.bench(id, pp, tg, pl, nr, promise);
}

@ReactMethod
public void releaseContext(double id, Promise promise) {
rnllama.releaseContext(id, promise);
5 changes: 5 additions & 0 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
@@ -78,6 +78,11 @@ public void embedding(double id, final String text, final Promise promise) {
rnllama.embedding(id, text, promise);
}

@ReactMethod
public void bench(double id, final double pp, final double tg, final double pl, final double nr, final Promise promise) {
rnllama.bench(id, pp, tg, pl, nr, promise);
}

@ReactMethod
public void releaseContext(double id, Promise promise) {
rnllama.releaseContext(id, promise);
118 changes: 118 additions & 0 deletions cpp/rn-llama.hpp
@@ -8,6 +8,21 @@

namespace rnllama {

static void llama_batch_clear(llama_batch *batch) {
batch->n_tokens = 0;
}

static void llama_batch_add(llama_batch *batch, llama_token id, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
batch->token [batch->n_tokens] = id;
batch->pos [batch->n_tokens] = pos;
batch->n_seq_id[batch->n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); i++) {
batch->seq_id[batch->n_tokens][i] = seq_ids[i];
}
batch->logits [batch->n_tokens] = logits ? 1 : 0;
batch->n_tokens += 1;
}

// NOTE: Edit from https://github.com/ggerganov/llama.cpp/blob/master/examples/server/server.cpp

static void log(const char *level, const char *function, int line,
@@ -506,6 +521,109 @@ struct llama_rn_context
std::vector<float> embedding(data, data + n_embd);
return embedding;
}

std::string bench(int pp, int tg, int pl, int nr)
{
if (is_predicting) {
LOG_ERROR("cannot benchmark while predicting", "");
return std::string("[]");
}

is_predicting = true;

double pp_avg = 0;
double tg_avg = 0;

double pp_std = 0;
double tg_std = 0;

// TODO: move batch into llama_rn_context (related https://github.com/mybigday/llama.rn/issues/30)
llama_batch batch = llama_batch_init(512, 0, 1);

for (int i = 0; i < nr; i++)
{
llama_batch_clear(&batch);

const int n_tokens = pp;

for (int i = 0; i < n_tokens; i++)
{
llama_batch_add(&batch, 0, i, {0}, false);
}
batch.logits[batch.n_tokens - 1] = 1; // only the last prompt token needs logits

llama_kv_cache_clear(ctx);

const int64_t t_pp_start = llama_time_us();
if (llama_decode(ctx, batch) != 0)
{
LOG_ERROR("llama_decode() failed during prompt", "");
}
const int64_t t_pp_end = llama_time_us();
llama_kv_cache_clear(ctx);

if (is_interrupted) break;

const int64_t t_tg_start = llama_time_us();

for (int i = 0; i < tg; i++)
{
llama_batch_clear(&batch);

for (int j = 0; j < pl; j++)
{
llama_batch_add(&batch, 0, i, {j}, true);
}

if (llama_decode(ctx, batch) != 0)
{
LOG_ERROR("llama_decode() failed during text generation", "");
}
if (is_interrupted) break;
}

const int64_t t_tg_end = llama_time_us();

llama_kv_cache_clear(ctx);

const double t_pp = (t_pp_end - t_pp_start) / 1000000.0;
const double t_tg = (t_tg_end - t_tg_start) / 1000000.0;

const double speed_pp = pp / t_pp;
const double speed_tg = (pl * tg) / t_tg;

pp_avg += speed_pp;
tg_avg += speed_tg;

pp_std += speed_pp * speed_pp;
tg_std += speed_tg * speed_tg;
}

pp_avg /= nr;
tg_avg /= nr;

if (nr > 1) {
pp_std = sqrt(pp_std / (nr - 1) - pp_avg * pp_avg * nr / (nr - 1));
tg_std = sqrt(tg_std / (nr - 1) - tg_avg * tg_avg * nr / (nr - 1));
} else {
pp_std = 0;
tg_std = 0;
}

if (is_interrupted) llama_kv_cache_clear(ctx);
is_predicting = false;

char model_desc[128];
llama_model_desc(model, model_desc, sizeof(model_desc));
return std::string("[\"") + model_desc + std::string("\",") +
std::to_string(llama_model_size(model)) + std::string(",") +
std::to_string(llama_model_n_params(model)) + std::string(",") +
std::to_string(pp_avg) + std::string(",") +
std::to_string(pp_std) + std::string(",") +
std::to_string(tg_avg) + std::string(",") +
std::to_string(tg_std) +
std::string("]");
}
};

}
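A note on the statistics in `bench`: each repetition adds its measured throughput $x_i$ (tokens per second) to a running sum and a running sum of squares, and the final `pp_std`/`tg_std` expressions recover the sample standard deviation from those two accumulators:

$$
\bar{x} = \frac{1}{n}\sum_{i=1}^{n} x_i,
\qquad
s = \sqrt{\frac{\sum_{i=1}^{n} x_i^{2} - n\,\bar{x}^{2}}{n-1}}
$$

with $n = \texttt{nr}$; for a single repetition the deviation is defined as 0.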
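Per the "move json parse to api" commit, the native layers hand this result back as a raw JSON array string (`[modelDesc, modelSize, modelNParams, ppAvg, ppStd, tgAvg, tgStd]`, matching the `return` statement above), and the TypeScript API decodes it into a named object. A sketch of that decoding, assuming the `BenchResult` shape documented below (`parseBenchResult` is an illustrative name, not necessarily the helper actually used in `src/index.ts`):

```typescript
type BenchResult = {
  modelDesc: string
  modelSize: number
  modelNParams: number
  ppAvg: number
  ppStd: number
  tgAvg: number
  tgStd: number
}

// The native bench() returns a JSON array string:
//   ["<model desc>", modelSize, modelNParams, ppAvg, ppStd, tgAvg, tgStd]
function parseBenchResult(raw: string): BenchResult {
  const [modelDesc, modelSize, modelNParams, ppAvg, ppStd, tgAvg, tgStd] =
    JSON.parse(raw)
  return { modelDesc, modelSize, modelNParams, ppAvg, ppStd, tgAvg, tgStd }
}
```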
37 changes: 30 additions & 7 deletions docs/API/README.md
@@ -11,6 +11,7 @@ llama.rn

### Type Aliases

- [BenchResult](README.md#benchresult)
- [CompletionParams](README.md#completionparams)
- [ContextParams](README.md#contextparams)
- [TokenData](README.md#tokendata)
@@ -24,13 +25,35 @@ llama.rn

## Type Aliases

### BenchResult

Ƭ **BenchResult**: `Object`

#### Type declaration

| Name | Type |
| :------ | :------ |
| `modelDesc` | `string` |
| `modelNParams` | `number` |
| `modelSize` | `number` |
| `ppAvg` | `number` |
| `ppStd` | `number` |
| `tgAvg` | `number` |
| `tgStd` | `number` |

#### Defined in

[index.ts:43](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L43)

___

### CompletionParams

Ƭ **CompletionParams**: `Omit`<`NativeCompletionParams`, ``"emit_partial_completion"``\>

#### Defined in

[index.ts:40](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L40)
[index.ts:41](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L41)

___

@@ -40,7 +63,7 @@ ___

#### Defined in

[index.ts:38](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L38)
[index.ts:39](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L39)

___

@@ -57,7 +80,7 @@ ___

#### Defined in

[index.ts:28](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L28)
[index.ts:29](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L29)

## Functions

@@ -79,7 +102,7 @@ ___

#### Defined in

[grammar.ts:134](https://github.com/mybigday/llama.rn/blob/8738c99/src/grammar.ts#L134)
[grammar.ts:134](https://github.com/mybigday/llama.rn/blob/427a856/src/grammar.ts#L134)

___

@@ -99,7 +122,7 @@ ___

#### Defined in

[index.ts:127](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L127)
[index.ts:160](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L160)

___

@@ -113,7 +136,7 @@ ___

#### Defined in

[index.ts:143](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L143)
[index.ts:176](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L176)

___

@@ -133,4 +156,4 @@ ___

#### Defined in

[index.ts:123](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L123)
[index.ts:156](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L156)