Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to Swift LLM app #388

Merged
merged 1 commit into from
May 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
objects = {

/* Begin PBXBuildFile section */
8D60EEED2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin in Resources */ = {isa = PBXBuildFile; fileRef = 8D60EEEC2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin */; };
0687D6A32BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin in Resources */ = {isa = PBXBuildFile; fileRef = 0687D6A22BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin */; };
0B6641881CDD8AB5E5F6C3FE /* Pods_InferenceExample.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = F02369998C59A0B7AAC3A1D5 /* Pods_InferenceExample.framework */; };
8DCF4C452B99289E00427D77 /* InferenceExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8DCF4C442B99289E00427D77 /* InferenceExampleApp.swift */; };
8DCF4C472B99289E00427D77 /* ConversationScreen.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8DCF4C462B99289E00427D77 /* ConversationScreen.swift */; };
8DCF4C492B99289E00427D77 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 8DCF4C482B99289E00427D77 /* Assets.xcassets */; };
Expand All @@ -17,7 +18,8 @@
/* End PBXBuildFile section */

/* Begin PBXFileReference section */
8D60EEEC2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; path = "gemma-2b-it-cpu-int4.bin"; sourceTree = "<group>"; };
0687D6A22BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; name = "gemma-1.1-2b-it-gpu-int4.bin"; path = "../../../../../../Downloads/gemma-1.1-2b-it-gpu-int4.bin"; sourceTree = "<group>"; };
84EC0509CA2D0A791FE6A50D /* Pods-InferenceExample.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-InferenceExample.release.xcconfig"; path = "Target Support Files/Pods-InferenceExample/Pods-InferenceExample.release.xcconfig"; sourceTree = "<group>"; };
8DCF4C412B99289E00427D77 /* InferenceExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = InferenceExample.app; sourceTree = BUILT_PRODUCTS_DIR; };
8DCF4C442B99289E00427D77 /* InferenceExampleApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InferenceExampleApp.swift; sourceTree = "<group>"; };
8DCF4C462B99289E00427D77 /* ConversationScreen.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationScreen.swift; sourceTree = "<group>"; };
Expand All @@ -26,33 +28,47 @@
8DCF4C4C2B99289E00427D77 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
8DCF4C572B992B9C00427D77 /* ConversationViewModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationViewModel.swift; sourceTree = "<group>"; };
8DCF4C5B2B9939D700427D77 /* OnDeviceModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OnDeviceModel.swift; sourceTree = "<group>"; };
E0D7FB7ACC6233403D36FDF8 /* Pods-InferenceExample.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-InferenceExample.debug.xcconfig"; path = "Target Support Files/Pods-InferenceExample/Pods-InferenceExample.debug.xcconfig"; sourceTree = "<group>"; };
F02369998C59A0B7AAC3A1D5 /* Pods_InferenceExample.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_InferenceExample.framework; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
8DCF4C3E2B99289E00427D77 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
0B6641881CDD8AB5E5F6C3FE /* Pods_InferenceExample.framework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */

/* Begin PBXGroup section */
32B29128DD33338DDF219030 /* Frameworks */ = {
isa = PBXGroup;
children = (
F02369998C59A0B7AAC3A1D5 /* Pods_InferenceExample.framework */,
);
name = Frameworks;
sourceTree = "<group>";
};
540B7F154909C4C9EF376B57 /* Pods */ = {
isa = PBXGroup;
children = (
E0D7FB7ACC6233403D36FDF8 /* Pods-InferenceExample.debug.xcconfig */,
84EC0509CA2D0A791FE6A50D /* Pods-InferenceExample.release.xcconfig */,
);
path = Pods;
sourceTree = "<group>";
};
8DCF4C382B99289D00427D77 = {
isa = PBXGroup;
children = (
0687D6A22BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin */,
8DCF4C432B99289E00427D77 /* InferenceExample */,
8DCF4C422B99289E00427D77 /* Products */,
540B7F154909C4C9EF376B57 /* Pods */,
8D60EEEC2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin */,
32B29128DD33338DDF219030 /* Frameworks */,
);
sourceTree = "<group>";
};
Expand Down Expand Up @@ -93,6 +109,7 @@
isa = PBXNativeTarget;
buildConfigurationList = 8DCF4C502B99289E00427D77 /* Build configuration list for PBXNativeTarget "InferenceExample" */;
buildPhases = (
3C9FFF9E17DFA2F74C306852 /* [CP] Check Pods Manifest.lock */,
8DCF4C3D2B99289E00427D77 /* Sources */,
8DCF4C3E2B99289E00427D77 /* Frameworks */,
8DCF4C3F2B99289E00427D77 /* Resources */,
Expand Down Expand Up @@ -146,12 +163,37 @@
files = (
8DCF4C4D2B99289E00427D77 /* Preview Assets.xcassets in Resources */,
8DCF4C492B99289E00427D77 /* Assets.xcassets in Resources */,
8D60EEED2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin in Resources */,
0687D6A32BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */

/* Begin PBXShellScriptBuildPhase section */
3C9FFF9E17DFA2F74C306852 /* [CP] Check Pods Manifest.lock */ = {
isa = PBXShellScriptBuildPhase;
buildActionMask = 2147483647;
files = (
);
inputFileListPaths = (
);
inputPaths = (
"${PODS_PODFILE_DIR_PATH}/Podfile.lock",
"${PODS_ROOT}/Manifest.lock",
);
name = "[CP] Check Pods Manifest.lock";
outputFileListPaths = (
);
outputPaths = (
"$(DERIVED_FILE_DIR)/Pods-InferenceExample-checkManifestLockResult.txt",
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n";
showEnvVarsInLog = 0;
};
/* End PBXShellScriptBuildPhase section */

/* Begin PBXSourcesBuildPhase section */
8DCF4C3D2B99289E00427D77 /* Sources */ = {
isa = PBXSourcesBuildPhase;
Expand Down Expand Up @@ -283,13 +325,15 @@
};
8DCF4C512B99289E00427D77 /* Debug */ = {
isa = XCBuildConfiguration;
baseConfigurationReference = E0D7FB7ACC6233403D36FDF8 /* Pods-InferenceExample.debug.xcconfig */;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_ENTITLEMENTS = InferenceExample/InferenceExample.entitlements;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_ASSET_PATHS = "\"InferenceExample/Preview Content\"";
DEVELOPMENT_TEAM = M3D535GFVK;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
Expand All @@ -307,7 +351,7 @@
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.2;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample;
PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample.foo;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
Expand All @@ -319,13 +363,15 @@
};
8DCF4C522B99289E00427D77 /* Release */ = {
isa = XCBuildConfiguration;
baseConfigurationReference = 84EC0509CA2D0A791FE6A50D /* Pods-InferenceExample.release.xcconfig */;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_ENTITLEMENTS = InferenceExample/InferenceExample.entitlements;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_ASSET_PATHS = "\"InferenceExample/Preview Content\"";
DEVELOPMENT_TEAM = M3D535GFVK;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
Expand All @@ -343,7 +389,7 @@
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.2;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample;
PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample.foo;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ class ConversationViewModel: ObservableObject {
messages.append(systemMessage)

do {
let response = try await chat.sendMessage(text)
let response = try await chat.sendMessage(text, progress : { [weak self] partialResult in
guard let self = self else { return }
DispatchQueue.main.async {
self.messages[self.messages.count - 1].message = partialResult
}
})

// replace pending message with model response
messages[messages.count - 1].message = response
Expand Down
44 changes: 33 additions & 11 deletions examples/llm_inference/ios/InferenceExample/OnDeviceModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,27 @@ import MediaPipeTasksGenAI

final class OnDeviceModel {

private var inference: LlmInference! = {
let path = Bundle.main.path(forResource: "gemma-2b-it-cpu-int4", ofType: "bin")!
let llmOptions = LlmInference.Options(modelPath: path)
return LlmInference(options: llmOptions)
}()

func generateResponse(prompt: String) async throws -> String {
// Backing store for `inference`; populated on first successful load so the
// expensive LlmInference initialization runs at most once per OnDeviceModel.
private var cachedInference: LlmInference?

// Lazily created LLM engine backed by the bundled Gemma model.
// - Throws: `CocoaError(.fileNoSuchFile)` if the model binary is not in the
//   app bundle, or whatever `LlmInference(options:)` throws on init failure.
// NOTE(review): the Xcode project references the .bin via a developer-local
// path (…/Downloads/…), so the resource genuinely may be missing at runtime;
// throw a descriptive error instead of force-unwrapping and crashing.
private var inference: LlmInference {
    get throws {
        if let cached = cachedInference {
            return cached
        }
        guard let path = Bundle.main.path(forResource: "gemma-1.1-2b-it-gpu-int4", ofType: "bin") else {
            throw CocoaError(.fileNoSuchFile)
        }
        let llmOptions = LlmInference.Options(modelPath: path)
        let created = try LlmInference(options: llmOptions)
        cachedInference = created
        return created
    }
}

func generateResponse(prompt: String, progress: @escaping (String) -> Void) async throws -> String {
var partialResult = ""

let inference = try inference
return try await withCheckedThrowingContinuation { continuation in
do {
try inference.generateResponseAsync(inputText: prompt) { partialResponse, error in
Expand All @@ -34,6 +47,7 @@ final class OnDeviceModel {
}
if let partial = partialResponse {
partialResult += partial
progress(partialResult.trimmingCharacters(in: .whitespacesAndNewlines))
}
} completion: {
let aggregate = partialResult.trimmingCharacters(in: .whitespacesAndNewlines)
Expand Down Expand Up @@ -62,16 +76,24 @@ final class Chat {
self.model = model
}

// Wraps a user message in Gemma's chat-turn delimiters
// (`<start_of_turn>user … <end_of_turn>`), ready for prompt assembly.
private func composeUserTurn(_ newMessage: String) -> String {
    let openingTag = "<start_of_turn>user\n"
    let closingTag = "<end_of_turn>\n"
    return openingTag + newMessage + closingTag
}

// Wraps a model reply in Gemma's chat-turn delimiters
// (`<start_of_turn>model … <end_of_turn>`), mirroring `composeUserTurn`.
private func composeModelTurn(_ newMessage: String) -> String {
    let openingTag = "<start_of_turn>model\n"
    let closingTag = "<end_of_turn>\n"
    return openingTag + newMessage + closingTag
}

// Builds the full prompt: all prior history turns joined by newlines,
// followed by the new message.
// NOTE: when `history` is empty this intentionally yields a leading
// newline before `newMessage`, matching the established prompt format.
private func compositePrompt(newMessage: String) -> String {
    let priorTurns = history.joined(separator: "\n")
    return "\(priorTurns)\n\(newMessage)"
}

func sendMessage(_ text: String) async throws -> String {
let prompt = compositePrompt(newMessage: text)
let reply = try await model.generateResponse(prompt: prompt)

history.append(text)
history.append(reply)
func sendMessage(_ text: String, progress: @escaping (String) -> Void) async throws -> String {
let prompt = compositePrompt(newMessage: composeUserTurn(text))
let reply = try await model.generateResponse(prompt: prompt, progress: progress)

history = [prompt, composeModelTurn(reply)]

print("Prompt: \(prompt)")
print("Reply: \(reply)")
Expand Down
10 changes: 5 additions & 5 deletions examples/llm_inference/ios/Podfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PODS:
- MediaPipeTasksGenAI (0.10.11):
- MediaPipeTasksGenAIC (= 0.10.11)
- MediaPipeTasksGenAIC (0.10.11)
- MediaPipeTasksGenAI (0.10.14):
- MediaPipeTasksGenAIC (= 0.10.14)
- MediaPipeTasksGenAIC (0.10.14)

DEPENDENCIES:
- MediaPipeTasksGenAI
Expand All @@ -12,8 +12,8 @@ SPEC REPOS:
- MediaPipeTasksGenAIC

SPEC CHECKSUMS:
MediaPipeTasksGenAI: 9fb3fc0e9f9329d0b3f89d741dbdb4cb4429b87a
MediaPipeTasksGenAIC: 9bb1f037b742d7d642c8b8fbec8b4626f73c18c5
MediaPipeTasksGenAI: 8cd77fa32ea21f7a6319b025aa28cfc3e20ab73b
MediaPipeTasksGenAIC: 270ec81f85e96fac283945702e34112ebbfd5e77

PODFILE CHECKSUM: b561fe84c5e19b81e1111ba0f8f21564f7006b85

Expand Down
Loading