Do not use Sessions for LLM Inference app (#456)
Sessions lead to crashes in append-only chat apps. We should use the non-Session API, as it creates a new implicit session for every message, thereby reducing the chance that we run out of tokens in the KV cache.
schmidt-sebastian authored Sep 30, 2024
1 parent 3576115 commit 8964924
Showing 2 changed files with 8 additions and 13 deletions.
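
For context, a minimal sketch of the two call patterns this commit switches between, assuming an already-configured `LlmInference` instance (the `inference` property of `OnDeviceModel` in this sample); the helper function name is illustrative:

import MediaPipeTasksGenAI

// Illustrative comparison; `inference` is an existing, configured LlmInference.
func compareCallPatterns(inference: LlmInference, prompt: String) throws {
  // Session API (removed by this commit): a long-lived session accumulates
  // every query chunk, so an append-only chat keeps filling the KV cache.
  let session = try LlmInference.Session(llmInference: inference)
  try session.addQueryChunk(inputText: prompt)
  let sessionStream = session.generateResponseAsync()

  // Non-Session API (adopted by this commit): each call creates a new
  // implicit session, so earlier messages do not keep consuming KV-cache tokens.
  let directStream = inference.generateResponseAsync(inputText: prompt)

  _ = (sessionStream, directStream)
}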
@@ -119,7 +119,7 @@ class ConversationViewModel: ObservableObject {
do {
let model = try OnDeviceModel(model: Model.gemma)
self.model = model
-chat = try Chat(model: model)
+chat = try Chat(inference: model.inference)
} catch let error as InferenceError {
self.error = error
} catch {
@@ -154,7 +154,7 @@ class ConversationViewModel: ObservableObject {
return
}
do {
-chat = try Chat(model: model)
+chat = try Chat(inference: model.inference)
messages.removeAll()
} catch {
self.error = InferenceError.mediaPipeTasksError(error: error)
17 changes: 6 additions & 11 deletions examples/llm_inference/ios/InferenceExample/OnDeviceModel.swift
@@ -29,15 +29,11 @@ struct OnDeviceModel {
/// Represents a chat session using an instance of `OnDeviceModel`. It manages a MediaPipe
/// `LlmInference.Session` under the hood and passes all response generation queries to the session.
final class Chat {
-/// The on device model using which this chat session was created.
-private let model: OnDeviceModel
-
-/// MediaPipe session managed by the current instance.
-private var session: LlmInference.Session
-
-init(model: OnDeviceModel) throws {
-  self.model = model
-  session = try LlmInference.Session(llmInference: model.inference)
+/// The on device inference engine for this chat session
+private let inference: LlmInference
+
+init(inference: LlmInference) throws {
+  self.inference = inference
}

/// Sends a streaming response generation query to the underlying MediaPipe
@@ -48,8 +44,7 @@ final class Chat {
/// - Throws: A MediaPipe `GenAiInferenceError` if the query cannot be added to the current
/// session.
func sendMessage(_ text: String) async throws -> AsyncThrowingStream<String, any Error> {
-try session.addQueryChunk(inputText: text)
-let resultStream = session.generateResponseAsync()
+let resultStream = inference.generateResponseAsync(inputText: text)
return resultStream
}
}
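
Since `sendMessage(_:)` still returns an `AsyncThrowingStream`, callers consume the stream the same way as before. A minimal consumption sketch, where the surrounding function and prompt are illustrative rather than part of the sample:

// Illustrative call site; `chat` is a Chat built from an existing LlmInference.
func printResponse(chat: Chat) async {
  do {
    let stream = try await chat.sendMessage("Tell me about MediaPipe.")
    for try await partialResult in stream {
      print(partialResult, terminator: "")  // partial responses stream in incrementally
    }
  } catch {
    print("Response generation failed: \(error)")
  }
}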
