Skip to content

Commit

Permalink
Enable tools
Browse files Browse the repository at this point in the history
  • Loading branch information
DePasqualeOrg committed Jan 19, 2025
1 parent fd16c00 commit 6b941a9
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 16 deletions.
6 changes: 4 additions & 2 deletions Sources/Models/LanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ extension LanguageModel {
get async throws {
guard _tokenizer == nil else { return _tokenizer! }
guard let tokenizerConfig = try await tokenizerConfig else {
throw "Cannot retrieve Tokenizer configuration"
throw TokenizerError.tokenizerConfigNotFound
}
let tokenizerData = try await tokenizerData
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
Expand All @@ -218,4 +218,6 @@ extension LanguageModel: TextGenerationModel {
}
}

extension String: Error {}
public enum TokenizerError: Error {
case tokenizerConfigNotFound
}
59 changes: 46 additions & 13 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import Foundation
import Hub
import Jinja

public typealias Message = [String: Any]
public typealias ToolSpec = [String: Any]

enum TokenizerError: Error {
case missingConfig
case missingTokenizerClassInConfig
Expand Down Expand Up @@ -133,22 +136,26 @@ public protocol Tokenizer {
var unknownTokenId: Int? { get }

/// The appropriate chat template is selected from the tokenizer config
func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
func applyChatTemplate(messages: [Message]) throws -> [Int]

/// The appropriate chat template is selected from the tokenizer config
func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int]

/// The chat template is provided as a string literal or specified by name
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]
func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int]

/// The chat template is provided as a string literal
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]
func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int]

func applyChatTemplate(
messages: [[String: String]],
messages: [Message],
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?,
tools: [[String: Any]]?
tools: [ToolSpec]?,
additionalContext: [String: Any]?
) throws -> [Int]
}

Expand Down Expand Up @@ -356,20 +363,35 @@ public class PreTrainedTokenizer: Tokenizer {
model.convertIdToToken(id)
}

public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
public func applyChatTemplate(messages: [Message]) throws -> [Int] {
try applyChatTemplate(messages: messages, addGenerationPrompt: true)
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] {
try applyChatTemplate(messages: messages, addGenerationPrompt: true, tools: tools)
}

public func applyChatTemplate(messages: [Message], tools: [ToolSpec], additionalContext: [String: Any]) throws
-> [Int]
{
try applyChatTemplate(
messages: messages,
addGenerationPrompt: true,
tools: tools,
additionalContext: additionalContext
)
}

public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true)
}

public func applyChatTemplate(
messages: [[String: String]],
messages: [Message],
chatTemplate: ChatTemplateArgument? = nil,
addGenerationPrompt: Bool = false,
truncation: Bool = false,
Expand All @@ -379,8 +401,8 @@ public class PreTrainedTokenizer: Tokenizer {
/// giving the name, description and argument types for the tool. See the
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
/// for more information.
/// Note: tool calling is not supported yet, it will be available in a future update.
tools: [[String: Any]]? = nil
tools: [ToolSpec]? = nil,
additionalContext: [String: Any]? = nil
) throws -> [Int] {
var selectedChatTemplate: String?
if let chatTemplate, case .literal(let template) = chatTemplate {
Expand Down Expand Up @@ -425,9 +447,20 @@ public class PreTrainedTokenizer: Tokenizer {
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt,
// TODO: Add `tools` entry when support is added in Jinja
// "tools": tools
]
if let tools {
context["tools"] = tools
}
if let additionalContext {
/*
Additional keys and values to be added to the context provided to the prompt templating engine.
For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
*/
for (key, value) in additionalContext {
context[key] = value
}
}

// TODO: maybe keep NSString here
for (key, value) in tokenizerConfig.dictionary as [String: Any] {
Expand Down
2 changes: 1 addition & 1 deletion Tests/TokenizersTests/TokenizerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ class TokenizerTester {
guard _tokenizer == nil else { return _tokenizer! }
do {
guard let tokenizerConfig = try await configuration!.tokenizerConfig else {
throw "Cannot retrieve Tokenizer configuration"
throw TokenizerError.tokenizerConfigNotFound
}
let tokenizerData = try await configuration!.tokenizerData
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
Expand Down

0 comments on commit 6b941a9

Please sign in to comment.