Enable tool use #151

Merged
5 commits merged on Jan 29, 2025
4 changes: 2 additions & 2 deletions Package.swift
@@ -12,8 +12,8 @@ let package = Package(
.executable(name: "hub-cli", targets: ["HubCLI"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"),
.package(url: "https://github.com/johnmai-dev/Jinja", from: "1.1.0")
.package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMinor(from: "1.4.0")),
.package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0"))
],
targets: [
.executableTarget(
8 changes: 6 additions & 2 deletions Sources/Models/LanguageModel.swift
@@ -190,7 +190,9 @@ public extension LanguageModel {
var tokenizer: Tokenizer {
get async throws {
guard _tokenizer == nil else { return _tokenizer! }
- guard let tokenizerConfig = try await tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" }
+ guard let tokenizerConfig = try await tokenizerConfig else {
+     throw TokenizerError.tokenizerConfigNotFound
+ }
let tokenizerData = try await tokenizerData
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
return _tokenizer!
@@ -212,4 +214,6 @@ extension LanguageModel: TextGenerationModel {
}
}

- extension String: Error {}
+ public enum TokenizerError: Error {
+     case tokenizerConfigNotFound
+ }
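For context, here is a minimal sketch (not part of the diff) of how calling code might handle the new typed error instead of the old String-as-Error value. The function name and printed messages are illustrative only, and it assumes the Models module is imported:

import Models

func loadTokenizer(for model: LanguageModel) async {
    do {
        let tokenizer = try await model.tokenizer
        print("Loaded tokenizer: \(type(of: tokenizer))")
    } catch TokenizerError.tokenizerConfigNotFound {
        // Thrown where the code previously did `throw "Cannot retrieve Tokenizer configuration"`.
        print("The model repo does not provide a tokenizer configuration.")
    } catch {
        print("Unexpected error: \(error)")
    }
}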
102 changes: 88 additions & 14 deletions Sources/Tokenizers/Tokenizer.swift
@@ -9,6 +9,9 @@ import Hub
import Foundation
import Jinja

public typealias Message = [String: Any]
public typealias ToolSpec = [String: Any]

enum TokenizerError: Error {
case missingConfig
case missingTokenizerClassInConfig
@@ -142,23 +145,57 @@ public protocol Tokenizer {
var unknownTokenId: Int? { get }

/// The appropriate chat template is selected from the tokenizer config
- func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
+ func applyChatTemplate(messages: [Message]) throws -> [Int]

/// The appropriate chat template is selected from the tokenizer config
func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int]

/// The chat template is provided as a string literal or specified by name
- func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]
+ func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int]

/// The chat template is provided as a string literal
- func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]
+ func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int]
Member:

Suggested change:
func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int]
func applyChatTemplate(
messages: [Message],
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?,
tools: [ToolSpec]?
) throws -> [Int]

We keep the previous declaration in the protocol.

Member:
But we also need something like this:

extension Tokenizer {
    /// Call previous signature for backwards compatibility
    func applyChatTemplate(
        messages: [Message],
        /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
        chatTemplate: ChatTemplateArgument?,
        addGenerationPrompt: Bool,
        truncation: Bool,
        maxLength: Int?,
        tools: [ToolSpec]?,
        additionalContext: [String: Any]?
    ) throws -> [Int] {
        if additionalContext == nil {
            try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools)
        } else {
            throw TokenizerError.chatTemplate("Not implemented")
        }
    }
}

It's a bit confusing, but the idea is: if you have written your own Tokenizer implementation in your project, you will have implemented your own version of applyChatTemplate with the previous signature. But now there's a new protocol requirement that takes a new argument (the additionalContext) that does not exist in your code, because you implemented the old version. So your code will not compile. With the extension above, we call the "old" implementation, so the code will still compile and work.

I don't think this affects our implementation in PreTrainedTokenizer below, everything should just keep working.

What do you think?
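(For readers unfamiliar with this pattern, here is a self-contained sketch of the mechanism being proposed, using generic names rather than the library's actual protocol: adding a requirement that takes an extra parameter breaks existing conformances unless a protocol extension forwards the new call to the old one.)

protocol Templating {
    // Old requirement, which third-party code already implements.
    func render(messages: [String]) -> String
    // New requirement with an extra parameter.
    func render(messages: [String], context: [String: Any]?) -> String
}

extension Templating {
    // Default implementation: existing conformers keep compiling, and calls
    // without extra context are forwarded to the old entry point.
    func render(messages: [String], context: [String: Any]?) -> String {
        guard context == nil else { fatalError("context not supported by this conformer") }
        return render(messages: messages)
    }
}

// A conformer written against the old protocol still compiles and works.
struct LegacyTemplating: Templating {
    func render(messages: [String]) -> String {
        messages.joined(separator: "\n")
    }
}

let joined = LegacyTemplating().render(messages: ["a", "b"], context: nil)  // "a\nb"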

Contributor Author:
I think it would be easier if you make the change yourself. I'm not sure if I fully understand the problem.


func applyChatTemplate(
- messages: [[String: String]],
+ messages: [Message],
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?,
- tools: [[String: Any]]?
+ tools: [ToolSpec]?
) throws -> [Int]

func applyChatTemplate(
messages: [Message],
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?,
tools: [ToolSpec]?,
additionalContext: [String: Any]?
Member:
This argument actually introduces a breaking API change for users that provide their own Tokenizer implementations. Do you think we could introduce a default implementation somehow? For example, we could keep the previous declaration, and create a default implementation of this method that calls the old one when additionalContext is nil.

Otherwise we can always introduce this as part of 0.2.0, but I'd rather add it as an incremental update so most people can use it.

Contributor Author:
Sure, I added an overload for the previous function signature.

) throws -> [Int]
}

extension Tokenizer {
/// Call previous signature for backwards compatibility
func applyChatTemplate(
messages: [Message],
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?,
tools: [ToolSpec]?,
additionalContext: [String: Any]?
) throws -> [Int] {
if additionalContext == nil {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools)
} else {
throw TokenizerError.chatTemplate("Not implemented")
}
}
}

public extension Tokenizer {
@@ -359,20 +396,46 @@ public class PreTrainedTokenizer: Tokenizer {
model.convertIdToToken(id)
}

- public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
+ public func applyChatTemplate(messages: [Message]) throws -> [Int] {
try applyChatTemplate(messages: messages, addGenerationPrompt: true)
}

- public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] {
try applyChatTemplate(messages: messages, addGenerationPrompt: true, tools: tools)
}

public func applyChatTemplate(messages: [Message], tools: [ToolSpec], additionalContext: [String: Any]) throws
-> [Int]
{
try applyChatTemplate(
messages: messages,
addGenerationPrompt: true,
tools: tools,
additionalContext: additionalContext
)
}

public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
}

- public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
+ public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true)
}

public func applyChatTemplate(
- messages: [[String: String]],
+ messages: [Message],
chatTemplate: ChatTemplateArgument? = nil,
addGenerationPrompt: Bool = false,
truncation: Bool = false,
maxLength: Int? = nil,
tools: [ToolSpec]? = nil
) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools, additionalContext: nil)
}

public func applyChatTemplate(
messages: [Message],
chatTemplate: ChatTemplateArgument? = nil,
addGenerationPrompt: Bool = false,
truncation: Bool = false,
@@ -382,8 +445,8 @@ public class PreTrainedTokenizer: Tokenizer {
/// giving the name, description and argument types for the tool. See the
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
/// for more information.
- /// Note: tool calling is not supported yet, it will be available in a future update.
- tools: [[String: Any]]? = nil
+ tools: [ToolSpec]? = nil,
+ additionalContext: [String: Any]? = nil
) throws -> [Int] {
var selectedChatTemplate: String?
if let chatTemplate, case .literal(let template) = chatTemplate {
@@ -425,10 +488,21 @@ public class PreTrainedTokenizer: Tokenizer {
let template = try Template(selectedChatTemplate)
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt
// TODO: Add `tools` entry when support is added in Jinja
// "tools": tools
"add_generation_prompt": addGenerationPrompt,
]
if let tools {
context["tools"] = tools
}
if let additionalContext {
/*
Additional keys and values to be added to the context provided to the prompt templating engine.
For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
*/
for (key, value) in additionalContext {
context[key] = value
}
}

// TODO: maybe keep NSString here
for (key, value) in tokenizerConfig.dictionary as [String : Any] {
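To make the new surface concrete, here is a minimal usage sketch (not part of the diff). The model id mirrors the one used in the tests below; the tool spec and function name are illustrative, and the full-signature call passes every argument because the protocol requirement has no defaults:

import Tokenizers

func encodeWeatherPrompt() async throws -> [Int] {
    let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2.5-7B-Instruct-4bit")

    let messages: [Message] = [
        ["role": "user", "content": "What is the weather in Paris today?"]
    ]
    let getCurrentWeather: ToolSpec = [
        "type": "function",
        "function": [
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": [
                "type": "object",
                "properties": [
                    "location": ["type": "string", "description": "The city and state, e.g. San Francisco, CA"]
                ],
                "required": ["location"]
            ]
        ]
    ]

    // Simple form: tools only, generation prompt added by default.
    let tokens = try tokenizer.applyChatTemplate(messages: messages, tools: [getCurrentWeather])

    // Full form: additionalContext feeds extra keys to the template, e.g. asking the
    // Llama 3.1/3.2 templates to place tool definitions in the system message.
    _ = try tokenizer.applyChatTemplate(
        messages: messages,
        chatTemplate: nil,
        addGenerationPrompt: true,
        truncation: false,
        maxLength: nil,
        tools: [getCurrentWeather],
        additionalContext: ["tools_in_user_message": false]
    )

    return tokens
}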
90 changes: 89 additions & 1 deletion Tests/TokenizersTests/ChatTemplateTests.swift
@@ -80,5 +80,93 @@ class ChatTemplateTests: XCTestCase {
XCTAssertEqual(decoded, decodedTarget)
}

- // TODO: Add tests for tool use template
func testQwen2_5WithTools() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2.5-7B-Instruct-4bit")

let weatherQueryMessages: [[String: String]] = [
[
"role": "user",
"content": "What is the weather in Paris today?",
]
]

let getCurrentWeatherToolSpec: [String: Any] = [
"type": "function",
"function": [
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": [
"type": "object",
"properties": [
"location": [
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
],
"unit": [
"type": "string",
"enum": ["celsius", "fahrenheit"]
]
],
"required": ["location"]
]
]
]

let encoded = try tokenizer.applyChatTemplate(messages: weatherQueryMessages, tools: [getCurrentWeatherToolSpec])
let decoded = tokenizer.decode(tokens: encoded)

func assertDictsAreEqual(_ actual: [String: Any], _ expected: [String: Any]) {
for (key, value) in actual {
if let nestedDict = value as? [String: Any], let nestedDict2 = expected[key] as? [String: Any] {
assertDictsAreEqual(nestedDict, nestedDict2)
} else if let arrayValue = value as? [String] {
let expectedArrayValue = expected[key] as? [String]
XCTAssertNotNil(expectedArrayValue)
XCTAssertEqual(Set(arrayValue), Set(expectedArrayValue!))
} else {
XCTAssertEqual(value as? String, expected[key] as? String)
}
}
}

if let startRange = decoded.range(of: "<tools>\n"),
let endRange = decoded.range(of: "\n</tools>", range: startRange.upperBound..<decoded.endIndex) {
let toolsSection = String(decoded[startRange.upperBound..<endRange.lowerBound])
if let toolsDict = try? JSONSerialization.jsonObject(with: toolsSection.data(using: .utf8)!) as? [String : Any] {
assertDictsAreEqual(toolsDict, getCurrentWeatherToolSpec)
} else {
XCTFail("Failed to decode tools section")
}
} else {
XCTFail("Failed to find tools section")
}

let expectedPromptStart = """
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.

# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
"""

let expectedPromptEnd = """
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call><|im_end|>
<|im_start|>user
What is the weather in Paris today?<|im_end|>
<|im_start|>assistant

"""

XCTAssertTrue(decoded.hasPrefix(expectedPromptStart), "Prompt should start with expected system message")
XCTAssertTrue(decoded.hasSuffix(expectedPromptEnd), "Prompt should end with expected format")
}
}
4 changes: 3 additions & 1 deletion Tests/TokenizersTests/TokenizerTests.swift
@@ -259,7 +259,9 @@ class TokenizerTester {
get async {
guard _tokenizer == nil else { return _tokenizer! }
do {
- guard let tokenizerConfig = try await configuration!.tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" }
+ guard let tokenizerConfig = try await configuration!.tokenizerConfig else {
+     throw TokenizerError.tokenizerConfigNotFound
+ }
let tokenizerData = try await configuration!.tokenizerData
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
} catch {