From 58d9638e0720c3164cd2c4bed35cd8bed5340971 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Sat, 11 May 2024 09:59:43 -0600 Subject: [PATCH] Improvements to Swift LLM app --- .../project.pbxproj | 58 +++++++++++++++++-- .../ConversationViewModel.swift | 7 ++- .../ios/InferenceExample/OnDeviceModel.swift | 44 ++++++++++---- examples/llm_inference/ios/Podfile.lock | 10 ++-- 4 files changed, 96 insertions(+), 23 deletions(-) diff --git a/examples/llm_inference/ios/InferenceExample.xcodeproj/project.pbxproj b/examples/llm_inference/ios/InferenceExample.xcodeproj/project.pbxproj index 0b785215..a3a25582 100644 --- a/examples/llm_inference/ios/InferenceExample.xcodeproj/project.pbxproj +++ b/examples/llm_inference/ios/InferenceExample.xcodeproj/project.pbxproj @@ -7,7 +7,8 @@ objects = { /* Begin PBXBuildFile section */ - 8D60EEED2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin in Resources */ = {isa = PBXBuildFile; fileRef = 8D60EEEC2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin */; }; + 0687D6A32BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin in Resources */ = {isa = PBXBuildFile; fileRef = 0687D6A22BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin */; }; + 0B6641881CDD8AB5E5F6C3FE /* Pods_InferenceExample.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = F02369998C59A0B7AAC3A1D5 /* Pods_InferenceExample.framework */; }; 8DCF4C452B99289E00427D77 /* InferenceExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8DCF4C442B99289E00427D77 /* InferenceExampleApp.swift */; }; 8DCF4C472B99289E00427D77 /* ConversationScreen.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8DCF4C462B99289E00427D77 /* ConversationScreen.swift */; }; 8DCF4C492B99289E00427D77 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 8DCF4C482B99289E00427D77 /* Assets.xcassets */; }; @@ -17,7 +18,8 @@ /* End PBXBuildFile section */ /* Begin PBXFileReference section */ - 8D60EEEC2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; path = "gemma-2b-it-cpu-int4.bin"; sourceTree = ""; }; + 0687D6A22BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; name = "gemma-1.1-2b-it-gpu-int4.bin"; path = "../../../../../../Downloads/gemma-1.1-2b-it-gpu-int4.bin"; sourceTree = ""; }; + 84EC0509CA2D0A791FE6A50D /* Pods-InferenceExample.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-InferenceExample.release.xcconfig"; path = "Target Support Files/Pods-InferenceExample/Pods-InferenceExample.release.xcconfig"; sourceTree = ""; }; 8DCF4C412B99289E00427D77 /* InferenceExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = InferenceExample.app; sourceTree = BUILT_PRODUCTS_DIR; }; 8DCF4C442B99289E00427D77 /* InferenceExampleApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InferenceExampleApp.swift; sourceTree = ""; }; 8DCF4C462B99289E00427D77 /* ConversationScreen.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationScreen.swift; sourceTree = ""; }; @@ -26,6 +28,8 @@ 8DCF4C4C2B99289E00427D77 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; 8DCF4C572B992B9C00427D77 /* ConversationViewModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationViewModel.swift; sourceTree = ""; }; 8DCF4C5B2B9939D700427D77 /* OnDeviceModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OnDeviceModel.swift; sourceTree = ""; }; + E0D7FB7ACC6233403D36FDF8 /* Pods-InferenceExample.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-InferenceExample.debug.xcconfig"; path = "Target Support Files/Pods-InferenceExample/Pods-InferenceExample.debug.xcconfig"; sourceTree = ""; }; + F02369998C59A0B7AAC3A1D5 /* Pods_InferenceExample.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_InferenceExample.framework; sourceTree = BUILT_PRODUCTS_DIR; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -33,15 +37,26 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( + 0B6641881CDD8AB5E5F6C3FE /* Pods_InferenceExample.framework in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; /* End PBXFrameworksBuildPhase section */ /* Begin PBXGroup section */ + 32B29128DD33338DDF219030 /* Frameworks */ = { + isa = PBXGroup; + children = ( + F02369998C59A0B7AAC3A1D5 /* Pods_InferenceExample.framework */, + ); + name = Frameworks; + sourceTree = ""; + }; 540B7F154909C4C9EF376B57 /* Pods */ = { isa = PBXGroup; children = ( + E0D7FB7ACC6233403D36FDF8 /* Pods-InferenceExample.debug.xcconfig */, + 84EC0509CA2D0A791FE6A50D /* Pods-InferenceExample.release.xcconfig */, ); path = Pods; sourceTree = ""; @@ -49,10 +64,11 @@ 8DCF4C382B99289D00427D77 = { isa = PBXGroup; children = ( + 0687D6A22BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin */, 8DCF4C432B99289E00427D77 /* InferenceExample */, 8DCF4C422B99289E00427D77 /* Products */, 540B7F154909C4C9EF376B57 /* Pods */, - 8D60EEEC2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin */, + 32B29128DD33338DDF219030 /* Frameworks */, ); sourceTree = ""; }; @@ -93,6 +109,7 @@ isa = PBXNativeTarget; buildConfigurationList = 8DCF4C502B99289E00427D77 /* Build configuration list for PBXNativeTarget "InferenceExample" */; buildPhases = ( + 3C9FFF9E17DFA2F74C306852 /* [CP] Check Pods Manifest.lock */, 8DCF4C3D2B99289E00427D77 /* Sources */, 8DCF4C3E2B99289E00427D77 /* Frameworks */, 8DCF4C3F2B99289E00427D77 /* Resources */, @@ -146,12 +163,37 @@ files = ( 8DCF4C4D2B99289E00427D77 /* Preview Assets.xcassets in Resources */, 8DCF4C492B99289E00427D77 /* Assets.xcassets in Resources */, - 8D60EEED2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin in Resources */, + 0687D6A32BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin in Resources */, ); runOnlyForDeploymentPostprocessing = 0; }; /* End PBXResourcesBuildPhase section */ +/* Begin PBXShellScriptBuildPhase section */ + 3C9FFF9E17DFA2F74C306852 /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputFileListPaths = ( + ); + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-InferenceExample-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; +/* End PBXShellScriptBuildPhase section */ + /* Begin PBXSourcesBuildPhase section */ 8DCF4C3D2B99289E00427D77 /* Sources */ = { isa = PBXSourcesBuildPhase; @@ -283,6 +325,7 @@ }; 8DCF4C512B99289E00427D77 /* Debug */ = { isa = XCBuildConfiguration; + baseConfigurationReference = E0D7FB7ACC6233403D36FDF8 /* Pods-InferenceExample.debug.xcconfig */; buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; @@ -290,6 +333,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; DEVELOPMENT_ASSET_PATHS = "\"InferenceExample/Preview Content\""; + DEVELOPMENT_TEAM = M3D535GFVK; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES; @@ -307,7 +351,7 @@ "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; MACOSX_DEPLOYMENT_TARGET = 14.2; MARKETING_VERSION = 1.0; - PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample; + PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample.foo; PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = auto; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx"; @@ -319,6 +363,7 @@ }; 8DCF4C522B99289E00427D77 /* Release */ = { isa = XCBuildConfiguration; + baseConfigurationReference = 84EC0509CA2D0A791FE6A50D /* Pods-InferenceExample.release.xcconfig */; buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; @@ -326,6 +371,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; DEVELOPMENT_ASSET_PATHS = "\"InferenceExample/Preview Content\""; + DEVELOPMENT_TEAM = M3D535GFVK; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES; @@ -343,7 +389,7 @@ "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; MACOSX_DEPLOYMENT_TARGET = 14.2; MARKETING_VERSION = 1.0; - PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample; + PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample.foo; PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = auto; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx"; diff --git a/examples/llm_inference/ios/InferenceExample/ConversationViewModel.swift b/examples/llm_inference/ios/InferenceExample/ConversationViewModel.swift index 52452d38..579f138f 100644 --- a/examples/llm_inference/ios/InferenceExample/ConversationViewModel.swift +++ b/examples/llm_inference/ios/InferenceExample/ConversationViewModel.swift @@ -89,7 +89,12 @@ class ConversationViewModel: ObservableObject { messages.append(systemMessage) do { - let response = try await chat.sendMessage(text) + let response = try await chat.sendMessage(text, progress : { [weak self] partialResult in + guard let self = self else { return } + DispatchQueue.main.async { + self.messages[self.messages.count - 1].message = partialResult + } + }) // replace pending message with model response messages[messages.count - 1].message = response diff --git a/examples/llm_inference/ios/InferenceExample/OnDeviceModel.swift b/examples/llm_inference/ios/InferenceExample/OnDeviceModel.swift index fdc05efa..d06e1f0a 100644 --- a/examples/llm_inference/ios/InferenceExample/OnDeviceModel.swift +++ b/examples/llm_inference/ios/InferenceExample/OnDeviceModel.swift @@ -17,14 +17,27 @@ import MediaPipeTasksGenAI final class OnDeviceModel { - private var inference: LlmInference! = { - let path = Bundle.main.path(forResource: "gemma-2b-it-cpu-int4", ofType: "bin")! - let llmOptions = LlmInference.Options(modelPath: path) - return LlmInference(options: llmOptions) - }() - func generateResponse(prompt: String) async throws -> String { + private var cachedInference: LlmInference? + + private var inference: LlmInference + { + get throws { + if let cached = cachedInference { + return cached + } else { + let path = Bundle.main.path(forResource: "gemma-1.1-2b-it-gpu-int4", ofType: "bin")! + let llmOptions = LlmInference.Options(modelPath: path) + cachedInference = try LlmInference(options: llmOptions) + return cachedInference! + } + } + } + + func generateResponse(prompt: String, progress: @escaping (String) -> Void) async throws -> String { var partialResult = "" + + let inference = try inference return try await withCheckedThrowingContinuation { continuation in do { try inference.generateResponseAsync(inputText: prompt) { partialResponse, error in @@ -34,6 +47,7 @@ final class OnDeviceModel { } if let partial = partialResponse { partialResult += partial + progress(partialResult.trimmingCharacters(in: .whitespacesAndNewlines)) } } completion: { let aggregate = partialResult.trimmingCharacters(in: .whitespacesAndNewlines) @@ -62,16 +76,24 @@ final class Chat { self.model = model } + private func composeUserTurn(_ newMessage: String) -> String { + return "user\n\(newMessage)\n" + } + + private func composeModelTurn(_ newMessage: String) -> String { + return "model\n\(newMessage)\n" + } + private func compositePrompt(newMessage: String) -> String { return history.joined(separator: "\n") + "\n" + newMessage } - func sendMessage(_ text: String) async throws -> String { - let prompt = compositePrompt(newMessage: text) - let reply = try await model.generateResponse(prompt: prompt) - history.append(text) - history.append(reply) + func sendMessage(_ text: String, progress: @escaping (String) -> Void) async throws -> String { + let prompt = compositePrompt(newMessage: composeUserTurn(text)) + let reply = try await model.generateResponse(prompt: prompt, progress: progress) + + history = [prompt, composeModelTurn(reply)] print("Prompt: \(prompt)") print("Reply: \(reply)") diff --git a/examples/llm_inference/ios/Podfile.lock b/examples/llm_inference/ios/Podfile.lock index d25c9dc5..1cfa7ea9 100644 --- a/examples/llm_inference/ios/Podfile.lock +++ b/examples/llm_inference/ios/Podfile.lock @@ -1,7 +1,7 @@ PODS: - - MediaPipeTasksGenAI (0.10.11): - - MediaPipeTasksGenAIC (= 0.10.11) - - MediaPipeTasksGenAIC (0.10.11) + - MediaPipeTasksGenAI (0.10.14): + - MediaPipeTasksGenAIC (= 0.10.14) + - MediaPipeTasksGenAIC (0.10.14) DEPENDENCIES: - MediaPipeTasksGenAI @@ -12,8 +12,8 @@ SPEC REPOS: - MediaPipeTasksGenAIC SPEC CHECKSUMS: - MediaPipeTasksGenAI: 9fb3fc0e9f9329d0b3f89d741dbdb4cb4429b87a - MediaPipeTasksGenAIC: 9bb1f037b742d7d642c8b8fbec8b4626f73c18c5 + MediaPipeTasksGenAI: 8cd77fa32ea21f7a6319b025aa28cfc3e20ab73b + MediaPipeTasksGenAIC: 270ec81f85e96fac283945702e34112ebbfd5e77 PODFILE CHECKSUM: b561fe84c5e19b81e1111ba0f8f21564f7006b85