From d2aeebfdd31e71a28012a29739c775c698d64b49 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Wed, 1 Jan 2025 19:00:45 +0100 Subject: [PATCH 1/5] Enable tools --- Sources/Models/LanguageModel.swift | 8 ++- Sources/Tokenizers/Tokenizer.swift | 61 +++++++++++++++++----- Tests/TokenizersTests/TokenizerTests.swift | 4 +- 3 files changed, 56 insertions(+), 17 deletions(-) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 457755a..22ba7aa 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -190,7 +190,9 @@ public extension LanguageModel { var tokenizer: Tokenizer { get async throws { guard _tokenizer == nil else { return _tokenizer! } - guard let tokenizerConfig = try await tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" } + guard let tokenizerConfig = try await tokenizerConfig else { + throw TokenizerError.tokenizerConfigNotFound + } let tokenizerData = try await tokenizerData _tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) return _tokenizer! @@ -212,4 +214,6 @@ extension LanguageModel: TextGenerationModel { } } -extension String: Error {} +public enum TokenizerError: Error { + case tokenizerConfigNotFound +} diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index cabfd76..fef5ccd 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -9,6 +9,9 @@ import Hub import Foundation import Jinja +public typealias Message = [String: Any] +public typealias ToolSpec = [String: Any] + enum TokenizerError: Error { case missingConfig case missingTokenizerClassInConfig @@ -142,22 +145,26 @@ public protocol Tokenizer { var unknownTokenId: Int? { get } /// The appropriate chat template is selected from the tokenizer config - func applyChatTemplate(messages: [[String: String]]) throws -> [Int] + func applyChatTemplate(messages: [Message]) throws -> [Int] + + /// The appropriate chat template is selected from the tokenizer config + func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] /// The chat template is provided as a string literal or specified by name - func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] + func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] /// The chat template is provided as a string literal - func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] + func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] func applyChatTemplate( - messages: [[String: String]], + messages: [Message], /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. chatTemplate: ChatTemplateArgument?, addGenerationPrompt: Bool, truncation: Bool, maxLength: Int?, - tools: [[String: Any]]? + tools: [ToolSpec]?, + additionalContext: [String: Any]? ) throws -> [Int] } @@ -359,20 +366,35 @@ public class PreTrainedTokenizer: Tokenizer { model.convertIdToToken(id) } - public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] { + public func applyChatTemplate(messages: [Message]) throws -> [Int] { try applyChatTemplate(messages: messages, addGenerationPrompt: true) } - public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] { + public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] { + try applyChatTemplate(messages: messages, addGenerationPrompt: true, tools: tools) + } + + public func applyChatTemplate(messages: [Message], tools: [ToolSpec], additionalContext: [String: Any]) throws + -> [Int] + { + try applyChatTemplate( + messages: messages, + addGenerationPrompt: true, + tools: tools, + additionalContext: additionalContext + ) + } + + public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] { try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true) } - public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] { + public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] { try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true) } public func applyChatTemplate( - messages: [[String: String]], + messages: [Message], chatTemplate: ChatTemplateArgument? = nil, addGenerationPrompt: Bool = false, truncation: Bool = false, @@ -382,8 +404,8 @@ public class PreTrainedTokenizer: Tokenizer { /// giving the name, description and argument types for the tool. See the /// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) /// for more information. - /// Note: tool calling is not supported yet, it will be available in a future update. - tools: [[String: Any]]? = nil + tools: [ToolSpec]? = nil, + additionalContext: [String: Any]? = nil ) throws -> [Int] { var selectedChatTemplate: String? if let chatTemplate, case .literal(let template) = chatTemplate { @@ -425,10 +447,21 @@ public class PreTrainedTokenizer: Tokenizer { let template = try Template(selectedChatTemplate) var context: [String: Any] = [ "messages": messages, - "add_generation_prompt": addGenerationPrompt - // TODO: Add `tools` entry when support is added in Jinja - // "tools": tools + "add_generation_prompt": addGenerationPrompt, ] + if let tools { + context["tools"] = tools + } + if let additionalContext { + /* + Additional keys and values to be added to the context provided to the prompt templating engine. + For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided. + The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message. + */ + for (key, value) in additionalContext { + context[key] = value + } + } // TODO: maybe keep NSString here for (key, value) in tokenizerConfig.dictionary as [String : Any] { diff --git a/Tests/TokenizersTests/TokenizerTests.swift b/Tests/TokenizersTests/TokenizerTests.swift index 0efd60d..2a09424 100644 --- a/Tests/TokenizersTests/TokenizerTests.swift +++ b/Tests/TokenizersTests/TokenizerTests.swift @@ -259,7 +259,9 @@ class TokenizerTester { get async { guard _tokenizer == nil else { return _tokenizer! } do { - guard let tokenizerConfig = try await configuration!.tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" } + guard let tokenizerConfig = try await configuration!.tokenizerConfig else { + throw TokenizerError.tokenizerConfigNotFound + } let tokenizerData = try await configuration!.tokenizerData _tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) } catch { From 96960da51a17e22289b158384b1c1ba502f162ec Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Sun, 26 Jan 2025 20:42:58 +0100 Subject: [PATCH 2/5] Add tool use test --- Tests/TokenizersTests/ChatTemplateTests.swift | 90 ++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/Tests/TokenizersTests/ChatTemplateTests.swift b/Tests/TokenizersTests/ChatTemplateTests.swift index 85b374d..dfadc26 100644 --- a/Tests/TokenizersTests/ChatTemplateTests.swift +++ b/Tests/TokenizersTests/ChatTemplateTests.swift @@ -80,5 +80,93 @@ class ChatTemplateTests: XCTestCase { XCTAssertEqual(decoded, decodedTarget) } - // TODO: Add tests for tool use template + func testQwen2_5WithTools() async throws { + let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2.5-7B-Instruct-4bit") + + let weatherQueryMessages: [[String: String]] = [ + [ + "role": "user", + "content": "What is the weather in Paris today?", + ] + ] + + let getCurrentWeatherToolSpec: [String: Any] = [ + "type": "function", + "function": [ + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": [ + "type": "object", + "properties": [ + "location": [ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + ], + "unit": [ + "type": "string", + "enum": ["celsius", "fahrenheit"] + ] + ], + "required": ["location"] + ] + ] + ] + + let encoded = try tokenizer.applyChatTemplate(messages: weatherQueryMessages, tools: [getCurrentWeatherToolSpec]) + let decoded = tokenizer.decode(tokens: encoded) + + func assertDictsAreEqual(_ actual: [String: Any], _ expected: [String: Any]) { + for (key, value) in actual { + if let nestedDict = value as? [String: Any], let nestedDict2 = expected[key] as? [String: Any] { + assertDictsAreEqual(nestedDict, nestedDict2) + } else if let arrayValue = value as? [String] { + let expectedArrayValue = expected[key] as? [String] + XCTAssertNotNil(expectedArrayValue) + XCTAssertEqual(Set(arrayValue), Set(expectedArrayValue!)) + } else { + XCTAssertEqual(value as? String, expected[key] as? String) + } + } + } + + if let startRange = decoded.range(of: "\n"), + let endRange = decoded.range(of: "\n", range: startRange.upperBound..system +You are Qwen, created by Alibaba Cloud. You are a helpful assistant. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +""" + + let expectedPromptEnd = """ + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What is the weather in Paris today?<|im_end|> +<|im_start|>assistant + +""" + + XCTAssertTrue(decoded.hasPrefix(expectedPromptStart), "Prompt should start with expected system message") + XCTAssertTrue(decoded.hasSuffix(expectedPromptEnd), "Prompt should end with expected format") + } } From 453b2bec365ba91a119dfdbb2f1f944faf335361 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Mon, 27 Jan 2025 15:13:11 +0100 Subject: [PATCH 3/5] Use latest patch versions of dependencies --- Package.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Package.swift b/Package.swift index ac6a853..fc28be1 100644 --- a/Package.swift +++ b/Package.swift @@ -12,8 +12,8 @@ let package = Package( .executable(name: "hub-cli", targets: ["HubCLI"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"), - .package(url: "https://github.com/johnmai-dev/Jinja", from: "1.1.0") + .package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMinor(from: "1.4.0")), + .package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0")) ], targets: [ .executableTarget( From e093fcaa44acb8ba6481fadbe60786fdab930472 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Wed, 29 Jan 2025 19:19:28 +0100 Subject: [PATCH 4/5] Keep previous `applyChatTemplate` --- Sources/Tokenizers/Tokenizer.swift | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index fef5ccd..e8f0058 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -393,6 +393,17 @@ public class PreTrainedTokenizer: Tokenizer { try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true) } + public func applyChatTemplate( + messages: [Message], + chatTemplate: ChatTemplateArgument? = nil, + addGenerationPrompt: Bool = false, + truncation: Bool = false, + maxLength: Int? = nil, + tools: [ToolSpec]? = nil + ) throws -> [Int] { + try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools, additionalContext: nil) + } + public func applyChatTemplate( messages: [Message], chatTemplate: ChatTemplateArgument? = nil, From cca036c84b385e186fffdeb5890f32c0b3b5a0b6 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 29 Jan 2025 20:18:55 +0100 Subject: [PATCH 5/5] Backwards source code compatibility --- Sources/Tokenizers/Tokenizer.swift | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index e8f0058..caa5fd6 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -156,6 +156,16 @@ public protocol Tokenizer { /// The chat template is provided as a string literal func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] + func applyChatTemplate( + messages: [Message], + /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. + chatTemplate: ChatTemplateArgument?, + addGenerationPrompt: Bool, + truncation: Bool, + maxLength: Int?, + tools: [ToolSpec]? + ) throws -> [Int] + func applyChatTemplate( messages: [Message], /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. @@ -168,6 +178,26 @@ public protocol Tokenizer { ) throws -> [Int] } +extension Tokenizer { + /// Call previous signature for backwards compatibility + func applyChatTemplate( + messages: [Message], + /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. + chatTemplate: ChatTemplateArgument?, + addGenerationPrompt: Bool, + truncation: Bool, + maxLength: Int?, + tools: [ToolSpec]?, + additionalContext: [String: Any]? + ) throws -> [Int] { + if additionalContext == nil { + try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools) + } else { + throw TokenizerError.chatTemplate("Not implemented") + } + } +} + public extension Tokenizer { func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] { encode(text: text, addSpecialTokens: addSpecialTokens)