Skip to content

Commit

Permalink
feat: make partial completion callback optional
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Sep 4, 2023
1 parent 95519e4 commit dddcb20
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 8 deletions.
10 changes: 8 additions & 2 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,15 @@ private void emitPartialCompletion(WritableMap tokenResult) {

private static class PartialCompletionCallback {
LlamaContext context;
boolean emitNeeded;

public PartialCompletionCallback(LlamaContext context) {
public PartialCompletionCallback(LlamaContext context, boolean emitNeeded) {
this.context = context;
this.emitNeeded = emitNeeded;
}

void onPartialCompletion(WritableMap tokenResult) {
if (!emitNeeded) return;
context.emitPartialCompletion(tokenResult);
}
}
Expand Down Expand Up @@ -151,7 +154,10 @@ public WritableMap completion(ReadableMap params) {
// double[][] logit_bias,
logit_bias,
// PartialCompletionCallback partial_completion_callback
new PartialCompletionCallback(this)
new PartialCompletionCallback(
this,
params.hasKey("emit_partial_completion") ? params.getBoolean("emit_partial_completion") : false
)
);
}

Expand Down
1 change: 1 addition & 0 deletions ios/RNLlama.mm
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ - (NSArray *)supportedEvents {
@autoreleasepool {
NSMutableDictionary* completionResult = [context completion:completionParams
onToken:^(NSMutableDictionary *tokenResult) {
if (completionParams[@"emit_partial_completion"] == false) return;
dispatch_async(dispatch_get_main_queue(), ^{
[self sendEventWithName:@"@RNLlama_onToken"
body:@{
Expand Down
2 changes: 2 additions & 0 deletions src/NativeRNLlama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ export type NativeCompletionParams = {

ignore_eos?: boolean
logit_bias?: Array<Array<number>>

emit_partial_completion: boolean
}

export type NativeCompletionTokenProbItem = {
Expand Down
15 changes: 9 additions & 6 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type TokenNativeEvent = {

export type ContextParams = NativeContextParams

export type CompletionParams = NativeCompletionParams
export type CompletionParams = Omit<NativeCompletionParams, 'emit_partial_response'>

export class LlamaContext {
id: number
Expand All @@ -52,25 +52,28 @@ export class LlamaContext {

async completion(
params: CompletionParams,
callback: (data: TokenData) => void,
callback?: (data: TokenData) => void,
) {
let tokenListener: any = EventEmitter.addListener(
let tokenListener: any = callback && EventEmitter.addListener(
EVENT_ON_TOKEN,
(evt: TokenNativeEvent) => {
const { contextId, tokenResult } = evt
if (contextId !== this.id) return
callback(tokenResult)
},
)
const promise = RNLlama.completion(this.id, params)
const promise = RNLlama.completion(this.id, {
...params,
emit_partial_completion: !!callback,
})
return promise
.then((completionResult) => {
tokenListener.remove()
tokenListener?.remove()
tokenListener = null
return completionResult
})
.catch((err: any) => {
tokenListener.remove()
tokenListener?.remove()
tokenListener = null
throw err
})
Expand Down

0 comments on commit dddcb20

Please sign in to comment.