Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support chat format #72

Merged
merged 6 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 75 additions & 39 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ You can search HuggingFace for available models (Keyword: [`GGUF`](https://huggi
To create a GGUF model manually, for example with Llama 2:

Download the Llama 2 model

1. Request access from [here](https://ai.meta.com/llama)
2. Download the model from HuggingFace [here](https://huggingface.co/meta-llama/Llama-2-7b-chat) (`Llama-2-7b-chat`)

Convert the model to ggml format

```bash
# Start with submodule in this repo (or you can clone the repo https://github.com/ggerganov/llama.cpp.git)
yarn && yarn bootstrap
Expand Down Expand Up @@ -76,26 +78,53 @@ const context = await initLlama({
// embedding: true, // use embedding
})

// Do completion
const { text, timings } = await context.completion(
const stopWords = ['</s>', '<|end|>', '<|eot_id|>', '<|end_of_text|>', '<|im_end|>', '<|EOT|>', '<|END_OF_TURN_TOKEN|>', '<|end_of_turn|>', '<|endoftext|>']

// Do chat completion
const msgResult = await context.completion(
{
messages: [
{
role: 'system',
content: 'This is a conversation between user and assistant, a friendly chatbot.',
},
{
role: 'user',
content: 'Hello!',
},
],
n_predict: 100,
stop: stopWords,
// ...other params
},
(data) => {
// This is a partial completion callback
const { token } = data
},
)
console.log('Result:', msgResult.text)
console.log('Timings:', msgResult.timings)

// Or do text completion
const textResult = await context.completion(
{
prompt: 'This is a conversation between user and llama, a friendly chatbot. respond in simple markdown.\n\nUser: Hello!\nLlama:',
n_predict: 100,
stop: ['</s>', 'Llama:', 'User:'],
// n_threads: 4,
stop: [...stopWords, 'Llama:', 'User:'],
// ...other params
},
(data) => {
// This is a partial completion callback
const { token } = data
},
)
console.log('Result:', text)
console.log('Timings:', timings)
console.log('Result:', textResult.text)
console.log('Timings:', textResult.timings)
```

The binding’s design is inspired by the [server.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) example in llama.cpp, so you can map its API to LlamaContext:

- `/completion`: `context.completion(params, partialCompletionCallback)`
- `/completion` and `/chat/completions`: `context.completion(params, partialCompletionCallback)`
- `/tokenize`: `context.tokenize(content)`
- `/detokenize`: `context.detokenize(tokens)`
- `/embedding`: `context.embedding(content)`
Expand All @@ -110,6 +139,7 @@ Please visit the [Documentation](docs/API) for more details.
You can also visit the [example](example) to see how to use it.

Run the example:

```bash
yarn && yarn bootstrap

Expand Down Expand Up @@ -142,7 +172,9 @@ You can see [GBNF Guide](https://github.com/ggerganov/llama.cpp/tree/master/gram
```js
import { initLlama, convertJsonSchemaToGrammar } from 'llama.rn'

const schema = { /* JSON Schema, see below */ }
const schema = {
/* JSON Schema, see below */
}

const context = await initLlama({
model: 'file://<path to gguf model>',
Expand All @@ -153,7 +185,7 @@ const context = await initLlama({
grammar: convertJsonSchemaToGrammar({
schema,
propOrder: { function: 0, arguments: 1 },
})
}),
})

const { text } = await context.completion({
Expand All @@ -171,80 +203,81 @@ console.log('Result:', text)
{
oneOf: [
{
type: "object",
name: "get_current_weather",
description: "Get the current weather in a given location",
type: 'object',
name: 'get_current_weather',
description: 'Get the current weather in a given location',
properties: {
function: {
const: "get_current_weather",
const: 'get_current_weather',
},
arguments: {
type: "object",
type: 'object',
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
type: 'string',
description: 'The city and state, e.g. San Francisco, CA',
},
unit: {
type: "string",
enum: ["celsius", "fahrenheit"],
type: 'string',
enum: ['celsius', 'fahrenheit'],
},
},
required: ["location"],
required: ['location'],
},
},
},
{
type: "object",
name: "create_event",
description: "Create a calendar event",
type: 'object',
name: 'create_event',
description: 'Create a calendar event',
properties: {
function: {
const: "create_event",
const: 'create_event',
},
arguments: {
type: "object",
type: 'object',
properties: {
title: {
type: "string",
description: "The title of the event",
type: 'string',
description: 'The title of the event',
},
date: {
type: "string",
description: "The date of the event",
type: 'string',
description: 'The date of the event',
},
time: {
type: "string",
description: "The time of the event",
type: 'string',
description: 'The time of the event',
},
},
required: ["title", "date", "time"],
required: ['title', 'date', 'time'],
},
},
},
{
type: "object",
name: "image_search",
description: "Search for an image",
type: 'object',
name: 'image_search',
description: 'Search for an image',
properties: {
function: {
const: "image_search",
const: 'image_search',
},
arguments: {
type: "object",
type: 'object',
properties: {
query: {
type: "string",
description: "The search query",
type: 'string',
description: 'The search query',
},
},
required: ["query"],
required: ['query'],
},
},
},
],
}
```

</details>

<details>
Expand All @@ -268,6 +301,7 @@ string ::= "\"" (
2 ::= "{" space "\"function\"" space ":" space 2-function "," space "\"arguments\"" space ":" space 2-arguments "}" space
root ::= 0 | 1 | 2
```

</details>

## Mock `llama.rn`
Expand All @@ -281,12 +315,14 @@ jest.mock('llama.rn', () => require('llama.rn/jest/mock'))
## NOTE

iOS:

- The [Extended Virtual Addressing](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_extended-virtual-addressing) capability is recommended to enable on iOS project.
- Metal:
- In our testing, some devices are unable to use Metal (`params.n_gpu_layers > 0`) because llama.cpp uses SIMD-scoped operations; you can check whether your device is supported in the [Metal feature set tables](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf) — an Apple7 GPU is the minimum requirement.
- It's also not supported in iOS simulator due to [this limitation](https://developer.apple.com/documentation/metal/developing_metal_apps_that_run_in_simulator#3241609), we used constant buffers more than 14.

Android:

- Currently only the arm64-v8a / x86_64 platforms are supported; this means you can't initialize a context on other platforms. 64-bit platforms are recommended because they can allocate more memory for the model.
- No integrated any GPU backend yet.

Expand Down
13 changes: 13 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ public WritableMap getModelDetails() {
return modelDetails;
}

// Formats an array of { role, content } chat messages into a single prompt
// string via the native chat-template implementation. A null template is
// normalized to "" so the native layer falls back to the model's own template.
public String getFormattedChat(ReadableArray messages, String chatTemplate) {
  final int count = messages.size();
  ReadableMap[] messageArray = new ReadableMap[count];
  for (int i = 0; i < count; i++) {
    messageArray[i] = messages.getMap(i);
  }
  String template = (chatTemplate != null) ? chatTemplate : "";
  return getFormattedChat(this.context, messageArray, template);
}

private void emitPartialCompletion(WritableMap tokenResult) {
WritableMap event = Arguments.createMap();
event.putInt("contextId", LlamaContext.this.id);
Expand Down Expand Up @@ -316,6 +324,11 @@ protected static native long initContext(
protected static native WritableMap loadModelDetails(
long contextPtr
);
protected static native String getFormattedChat(
long contextPtr,
ReadableMap[] messages,
String chatTemplate
);
protected static native WritableMap loadSession(
long contextPtr,
String path
Expand Down
32 changes: 32 additions & 0 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,38 @@ protected void onPostExecute(WritableMap result) {
tasks.put(task, "initContext");
}

/**
 * Applies the chat template to {@code messages} on a background thread and
 * resolves {@code promise} with the formatted prompt string.
 *
 * @param id           context id (bridged as double by React Native)
 * @param messages     array of maps with "role" and "content" entries
 * @param chatTemplate optional template override; may be null
 * @param promise      resolved with the formatted string, rejected with
 *                     "Context not found" or any native-layer failure
 */
public void getFormattedChat(double id, final ReadableArray messages, final String chatTemplate, Promise promise) {
  final int contextId = (int) id;
  AsyncTask task = new AsyncTask<Void, Void, String>() {
    private Exception exception;

    @Override
    protected String doInBackground(Void... voids) {
      try {
        LlamaContext context = contexts.get(contextId);
        if (context == null) {
          throw new Exception("Context not found");
        }
        return context.getFormattedChat(messages, chatTemplate);
      } catch (Exception e) {
        exception = e;
        return null;
      }
    }

    @Override
    protected void onPostExecute(String result) {
      // Remove the task unconditionally: the previous version early-returned
      // on rejection before removal, leaking the entry in `tasks` forever.
      tasks.remove(this);
      if (exception != null) {
        promise.reject(exception);
        return;
      }
      promise.resolve(result);
    }
  }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
  tasks.put(task, "getFormattedChat-" + contextId);
}

public void loadSession(double id, final String path, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
Expand Down
51 changes: 51 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ static inline void putDouble(JNIEnv *env, jobject map, const char *key, double v
env->CallVoidMethod(map, putDoubleMethod, jKey, value);
}

// Stores a boolean entry in a React Native WritableMap via its
// putBoolean(String, boolean) Java method.
static inline void putBoolean(JNIEnv *env, jobject map, const char *key, bool value) {
  jclass writableMapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
  jmethodID putBooleanMethod = env->GetMethodID(writableMapClass, "putBoolean", "(Ljava/lang/String;Z)V");
  jstring keyStr = env->NewStringUTF(key);
  env->CallVoidMethod(map, putBooleanMethod, keyStr, static_cast<jboolean>(value));
}

// Method to put WriteableMap into WritableMap
static inline void putMap(JNIEnv *env, jobject map, const char *key, jobject value) {
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
Expand Down Expand Up @@ -208,11 +218,52 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
putString(env, result, "desc", desc);
putDouble(env, result, "size", llama_model_size(llama->model));
putDouble(env, result, "nParams", llama_model_n_params(llama->model));
putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate());
putMap(env, result, "metadata", meta);

return reinterpret_cast<jobject>(result);
}

// Converts a Java array of { role, content } message maps into llama_chat_msg
// entries and returns the prompt produced by llama_chat_apply_template.
//
// chat_template: template override; the Java caller passes "" to select the
// model's built-in template.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_getFormattedChat(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jobjectArray messages,
    jstring chat_template
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    std::vector<llama_chat_msg> chat;

    int messages_len = env->GetArrayLength(messages);
    chat.reserve(messages_len);

    // Key strings are loop-invariant; create them once instead of per message.
    jstring role_key = env->NewStringUTF("role");
    jstring content_key = env->NewStringUTF("content");

    for (int i = 0; i < messages_len; i++) {
        jobject msg = env->GetObjectArrayElement(messages, i);
        jclass msg_class = env->GetObjectClass(msg);
        // ReadableMap.getString(String) — used for both "role" and "content"
        // (the old name `getRoleMethod` was misleading).
        jmethodID get_string = env->GetMethodID(msg_class, "getString", "(Ljava/lang/String;)Ljava/lang/String;");

        jstring role_str = (jstring) env->CallObjectMethod(msg, get_string, role_key);
        jstring content_str = (jstring) env->CallObjectMethod(msg, get_string, content_key);

        const char *role = env->GetStringUTFChars(role_str, nullptr);
        const char *content = env->GetStringUTFChars(content_str, nullptr);

        // NOTE(review): releasing right after push_back assumes llama_chat_msg
        // copies both strings — the original code relied on this as well.
        chat.push_back({ role, content });

        env->ReleaseStringUTFChars(role_str, role);
        env->ReleaseStringUTFChars(content_str, content);

        // Free per-iteration local refs so long conversations cannot
        // overflow the JNI local reference table.
        env->DeleteLocalRef(role_str);
        env->DeleteLocalRef(content_str);
        env->DeleteLocalRef(msg_class);
        env->DeleteLocalRef(msg);
    }

    env->DeleteLocalRef(role_key);
    env->DeleteLocalRef(content_key);

    const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
    std::string formatted_chat = llama_chat_apply_template(llama->model, tmpl_chars, chat, true);
    // The original never released tmpl_chars; GetStringUTFChars must always
    // be paired with ReleaseStringUTFChars.
    env->ReleaseStringUTFChars(chat_template, tmpl_chars);

    return env->NewStringUTF(formatted_chat.c_str());
}

JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_loadSession(
JNIEnv *env,
Expand Down
5 changes: 5 additions & 0 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ public void initContext(final ReadableMap params, final Promise promise) {
rnllama.initContext(params, promise);
}

// Bridges getFormattedChat from JS to the shared RNLlama implementation
// (new-architecture module). `id` arrives as double because the React Native
// bridge has no integer number type.
@ReactMethod
public void getFormattedChat(double id, ReadableArray messages, String chatTemplate, Promise promise) {
rnllama.getFormattedChat(id, messages, chatTemplate, promise);
}

@ReactMethod
public void loadSession(double id, String path, Promise promise) {
rnllama.loadSession(id, path, promise);
Expand Down
5 changes: 5 additions & 0 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ public void initContext(final ReadableMap params, final Promise promise) {
rnllama.initContext(params, promise);
}

// Bridges getFormattedChat from JS to the shared RNLlama implementation
// (old-architecture module). `id` arrives as double because the React Native
// bridge has no integer number type.
@ReactMethod
public void getFormattedChat(double id, ReadableArray messages, String chatTemplate, Promise promise) {
rnllama.getFormattedChat(id, messages, chatTemplate, promise);
}

@ReactMethod
public void loadSession(double id, String path, Promise promise) {
rnllama.loadSession(id, path, promise);
Expand Down
Loading
Loading