Skip to content

Commit

Permalink
Enable tools
Browse files Browse the repository at this point in the history
  • Loading branch information
DePasqualeOrg committed Jan 2, 2025
1 parent 92b5072 commit be1f482
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 16 deletions.
6 changes: 4 additions & 2 deletions Sources/Models/LanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ public extension LanguageModel {
get async throws {
guard _tokenizer == nil else { return _tokenizer! }
guard let tokenizerConfig = try await tokenizerConfig else {
throw "Cannot retrieve Tokenizer configuration"
throw TokenizerError.tokenizerConfigNotFound
}
let tokenizerData = try await tokenizerData
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
Expand All @@ -218,4 +218,6 @@ extension LanguageModel: TextGenerationModel {
}
}

extension String: Error {}
/// Errors thrown while building a tokenizer from its configuration.
public enum TokenizerError: Error {
/// The tokenizer configuration resolved to `nil`, so no tokenizer can be created from it.
case tokenizerConfigNotFound
}
38 changes: 25 additions & 13 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import Hub
import Foundation
import Jinja

/// One chat message as passed in the `messages:` argument of `applyChatTemplate`:
/// string keys mapped to arbitrary values. The expected keys are not defined here.
public typealias Message = [String: Any]
/// One tool description as passed in the `tools:` argument of `applyChatTemplate`:
/// string keys mapped to arbitrary values. The expected keys are not defined here.
public typealias ToolSpec = [String: Any]

enum TokenizerError: Error {
case missingConfig
case missingTokenizerClassInConfig
Expand Down Expand Up @@ -134,22 +137,25 @@ public protocol Tokenizer {
var unknownTokenId: Int? { get }

/// The appropriate chat template is selected from the tokenizer config
func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
func applyChatTemplate(messages: [Message]) throws -> [Int]

/// The appropriate chat template is selected from the tokenizer config
func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int]

/// The chat template is provided as a string literal or specified by name
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]
func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int]

/// The chat template is provided as a string literal
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]
func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int]

func applyChatTemplate(
messages: [[String: String]],
messages: [Message],
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?,
tools: [[String: Any]]?
tools: [ToolSpec]?
) throws -> [Int]
}

Expand Down Expand Up @@ -358,20 +364,24 @@ public class PreTrainedTokenizer: Tokenizer {
model.convertIdToToken(id)
}

public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
/// Applies the chat template selected from the tokenizer config to `messages`.
///
/// Convenience overload: forwards to the full `applyChatTemplate` overload with
/// `addGenerationPrompt` set to `true` and every other option left at its default.
///
/// - Parameter messages: The conversation to render.
/// - Returns: The `[Int]` produced by the full overload.
/// - Throws: Whatever the full overload throws.
public func applyChatTemplate(messages: [Message]) throws -> [Int] {
    return try self.applyChatTemplate(
        messages: messages,
        addGenerationPrompt: true
    )
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
/// Applies the chat template selected from the tokenizer config to `messages`,
/// making `tools` available to the template.
///
/// Convenience overload: forwards to the full `applyChatTemplate` overload with
/// `addGenerationPrompt` set to `true` and every other option left at its default.
///
/// - Parameters:
///   - messages: The conversation to render.
///   - tools: The tool descriptions to expose to the template.
/// - Returns: The `[Int]` produced by the full overload.
/// - Throws: Whatever the full overload throws.
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] {
    return try self.applyChatTemplate(
        messages: messages,
        addGenerationPrompt: true,
        tools: tools
    )
}

/// Applies an explicitly chosen chat template to `messages`.
///
/// Convenience overload: forwards to the full `applyChatTemplate` overload with
/// `addGenerationPrompt` set to `true` and every other option left at its default.
///
/// - Parameters:
///   - messages: The conversation to render.
///   - chatTemplate: The template to use, passed through unchanged.
/// - Returns: The `[Int]` produced by the full overload.
/// - Throws: Whatever the full overload throws.
public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] {
    return try self.applyChatTemplate(
        messages: messages,
        chatTemplate: chatTemplate,
        addGenerationPrompt: true
    )
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
/// Applies a chat template given as a string to `messages`.
///
/// Convenience overload: wraps `chatTemplate` in `.literal` and forwards to the full
/// `applyChatTemplate` overload with `addGenerationPrompt` set to `true` and every
/// other option left at its default.
///
/// - Parameters:
///   - messages: The conversation to render.
///   - chatTemplate: The template source text.
/// - Returns: The `[Int]` produced by the full overload.
/// - Throws: Whatever the full overload throws.
public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] {
    return try self.applyChatTemplate(
        messages: messages,
        chatTemplate: .literal(chatTemplate),
        addGenerationPrompt: true
    )
}

public func applyChatTemplate(
messages: [[String: String]],
messages: [Message],
chatTemplate: ChatTemplateArgument? = nil,
addGenerationPrompt: Bool = false,
truncation: Bool = false,
Expand All @@ -381,8 +391,7 @@ public class PreTrainedTokenizer: Tokenizer {
/// giving the name, description and argument types for the tool. See the
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
/// for more information.
/// Note: tool calling is not supported yet, it will be available in a future update.
tools: [[String: Any]]? = nil
tools: [ToolSpec]? = nil
) throws -> [Int] {
var selectedChatTemplate: String?
if let chatTemplate, case .literal(let template) = chatTemplate {
Expand Down Expand Up @@ -429,9 +438,12 @@ public class PreTrainedTokenizer: Tokenizer {
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt,
// TODO: Add `tools` entry when support is added in Jinja
// "tools": tools
]
if let tools {
context["tools"] = tools
// Performance might be better if the tools prompt is included in a system message rather than a user message, but then the system message must be present.
context["tools_in_user_message"] = false // Default is true in Llama 3.1 and 3.2 template
}

// TODO: maybe keep NSString here
for (key, value) in tokenizerConfig.dictionary as [String: Any] {
Expand Down
2 changes: 1 addition & 1 deletion Tests/TokenizersTests/TokenizerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class TokenizerTester {
guard _tokenizer == nil else { return _tokenizer! }
do {
guard let tokenizerConfig = try await configuration!.tokenizerConfig else {
throw "Cannot retrieve Tokenizer configuration"
throw TokenizerError.tokenizerConfigNotFound
}
let tokenizerData = try await configuration!.tokenizerData
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
Expand Down

0 comments on commit be1f482

Please sign in to comment.