-
Notifications
You must be signed in to change notification settings - Fork 94
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable tool use #151
Enable tool use #151
Changes from all commits
d2aeebf
96960da
453b2be
e093fca
cca036c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,9 @@ import Hub | |
import Foundation | ||
import Jinja | ||
|
||
public typealias Message = [String: Any] | ||
public typealias ToolSpec = [String: Any] | ||
|
||
enum TokenizerError: Error { | ||
case missingConfig | ||
case missingTokenizerClassInConfig | ||
|
@@ -142,23 +145,57 @@ public protocol Tokenizer { | |
var unknownTokenId: Int? { get } | ||
|
||
/// The appropriate chat template is selected from the tokenizer config | ||
func applyChatTemplate(messages: [[String: String]]) throws -> [Int] | ||
func applyChatTemplate(messages: [Message]) throws -> [Int] | ||
|
||
/// The appropriate chat template is selected from the tokenizer config | ||
func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] | ||
|
||
/// The chat template is provided as a string literal or specified by name | ||
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] | ||
func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] | ||
|
||
/// The chat template is provided as a string literal | ||
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] | ||
func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] | ||
|
||
func applyChatTemplate( | ||
messages: [[String: String]], | ||
messages: [Message], | ||
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. | ||
chatTemplate: ChatTemplateArgument?, | ||
addGenerationPrompt: Bool, | ||
truncation: Bool, | ||
maxLength: Int?, | ||
tools: [[String: Any]]? | ||
tools: [ToolSpec]? | ||
) throws -> [Int] | ||
|
||
func applyChatTemplate( | ||
messages: [Message], | ||
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. | ||
chatTemplate: ChatTemplateArgument?, | ||
addGenerationPrompt: Bool, | ||
truncation: Bool, | ||
maxLength: Int?, | ||
tools: [ToolSpec]?, | ||
additionalContext: [String: Any]? | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This argument actually introduces a breaking API change for users that provide their own `Tokenizer` implementation. Otherwise we can always introduce this as part of [inline code lost in extraction]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, I added an overload for the previous function signature. |
||
) throws -> [Int] | ||
} | ||
|
||
extension Tokenizer { | ||
/// Call previous signature for backwards compatibility | ||
func applyChatTemplate( | ||
messages: [Message], | ||
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. | ||
chatTemplate: ChatTemplateArgument?, | ||
addGenerationPrompt: Bool, | ||
truncation: Bool, | ||
maxLength: Int?, | ||
tools: [ToolSpec]?, | ||
additionalContext: [String: Any]? | ||
) throws -> [Int] { | ||
if additionalContext == nil { | ||
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools) | ||
} else { | ||
throw TokenizerError.chatTemplate("Not implemented") | ||
} | ||
} | ||
} | ||
|
||
public extension Tokenizer { | ||
|
@@ -359,20 +396,46 @@ public class PreTrainedTokenizer: Tokenizer { | |
model.convertIdToToken(id) | ||
} | ||
|
||
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] { | ||
public func applyChatTemplate(messages: [Message]) throws -> [Int] { | ||
try applyChatTemplate(messages: messages, addGenerationPrompt: true) | ||
} | ||
|
||
public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] { | ||
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] { | ||
try applyChatTemplate(messages: messages, addGenerationPrompt: true, tools: tools) | ||
} | ||
|
||
public func applyChatTemplate(messages: [Message], tools: [ToolSpec], additionalContext: [String: Any]) throws | ||
-> [Int] | ||
{ | ||
try applyChatTemplate( | ||
messages: messages, | ||
addGenerationPrompt: true, | ||
tools: tools, | ||
additionalContext: additionalContext | ||
) | ||
} | ||
|
||
public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] { | ||
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true) | ||
} | ||
|
||
public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] { | ||
public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] { | ||
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true) | ||
} | ||
|
||
public func applyChatTemplate( | ||
messages: [[String: String]], | ||
messages: [Message], | ||
chatTemplate: ChatTemplateArgument? = nil, | ||
addGenerationPrompt: Bool = false, | ||
truncation: Bool = false, | ||
maxLength: Int? = nil, | ||
tools: [ToolSpec]? = nil | ||
) throws -> [Int] { | ||
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools, additionalContext: nil) | ||
} | ||
|
||
public func applyChatTemplate( | ||
messages: [Message], | ||
chatTemplate: ChatTemplateArgument? = nil, | ||
addGenerationPrompt: Bool = false, | ||
truncation: Bool = false, | ||
|
@@ -382,8 +445,8 @@ public class PreTrainedTokenizer: Tokenizer { | |
/// giving the name, description and argument types for the tool. See the | ||
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) | ||
/// for more information. | ||
/// Note: tool calling is not supported yet, it will be available in a future update. | ||
tools: [[String: Any]]? = nil | ||
tools: [ToolSpec]? = nil, | ||
additionalContext: [String: Any]? = nil | ||
) throws -> [Int] { | ||
var selectedChatTemplate: String? | ||
if let chatTemplate, case .literal(let template) = chatTemplate { | ||
|
@@ -425,10 +488,21 @@ public class PreTrainedTokenizer: Tokenizer { | |
let template = try Template(selectedChatTemplate) | ||
var context: [String: Any] = [ | ||
"messages": messages, | ||
"add_generation_prompt": addGenerationPrompt | ||
// TODO: Add `tools` entry when support is added in Jinja | ||
// "tools": tools | ||
"add_generation_prompt": addGenerationPrompt, | ||
] | ||
if let tools { | ||
context["tools"] = tools | ||
} | ||
if let additionalContext { | ||
/* | ||
Additional keys and values to be added to the context provided to the prompt templating engine. | ||
For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided. | ||
The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message. | ||
*/ | ||
for (key, value) in additionalContext { | ||
context[key] = value | ||
} | ||
} | ||
|
||
// TODO: maybe keep NSString here | ||
for (key, value) in tokenizerConfig.dictionary as [String : Any] { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We keep the previous declaration in the protocol.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But we also need something like this:
It's a bit confusing, but the idea is: if you have written your own `Tokenizer` implementation in your project, you will have implemented your own version of `applyChatTemplate` with the previous signature. But now there's a new protocol requirement that takes a new argument (the `additionalContext`) that does not exist in your code, because you implemented the old version. So your code will not compile. With the extension above, we call the "old" implementation, so the code will still compile and work.

I don't think this affects our implementation in `PreTrainedTokenizer` below; everything should just keep working.

What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be easier if you make the change yourself. I'm not sure if I fully understand the problem.