Remove Instruct/Chat versions of models & introduce a new ChatTemplate API, fix Anthropic API #820

Merged May 15, 2024 (34 commits)
Changes from 22 commits

Commits:
363e73c
Undo some of the more egregious formatting changes made by black
Harsha-Nori May 12, 2024
15183a3
Refactor chat template representation and introduce a cache for popul…
Harsha-Nori May 12, 2024
1dcf00a
Updating chat template logic.
Harsha-Nori May 12, 2024
2818f8b
Change internal ChatTemplateCache to be a singleton instead of class …
Harsha-Nori May 12, 2024
fab8d14
Begin process of removing Chat versions of Transformers models.
Harsha-Nori May 12, 2024
7c11167
Add llamacpp support and fix bugs in transformers models. Should supp…
Harsha-Nori May 13, 2024
cc7c4d1
Remove references to separate chat based local models.
Harsha-Nori May 13, 2024
23fe8f3
Remove llama subclass of Transformers as this now works out of the box.
Harsha-Nori May 13, 2024
b792490
Also strip newly deleted classes from __init__ py files.
Harsha-Nori May 13, 2024
3e9b20b
Refactor Anthropic models to use their new API, and add support for n…
Harsha-Nori May 13, 2024
b9ae95c
Fix Phi-3 in chat mode. Add tests for some future debugging we need t…
Harsha-Nori May 13, 2024
80068b6
bugfix for llama3 tokenizer.
Harsha-Nori May 14, 2024
1dd2f43
change chat_template_cache test to use env var for github ci
Harsha-Nori May 14, 2024
f2ea53b
refactor test_chat_templates to use pytest.mark.parametrize
Harsha-Nori May 14, 2024
410a17e
skip tests on bigger models used for local debugging that wont fit on…
Harsha-Nori May 14, 2024
56a8e93
add exception handling if someone passes a llamacpptokenizer to remot…
Harsha-Nori May 14, 2024
98846d7
openAI now works with new interface, azureopenai and togetherai need …
Harsha-Nori May 14, 2024
ac931b6
WIP support for AzureOpenAI (which is currently down so I can't test it)
Harsha-Nori May 14, 2024
9232667
remove references to deleted classes.
Harsha-Nori May 14, 2024
b230435
Update ci_tests.yml
Harsha-Nori May 14, 2024
b5f9a0a
bugfix for mistral7b chattemplatecache
Harsha-Nori May 14, 2024
c1bb15c
Merge branch 'nochat' of github.com:guidance-ai/guidance into nochat
Harsha-Nori May 14, 2024
be7e955
Merge remote-tracking branch 'upstream/main' into nochat
riedgar-ms May 14, 2024
2fb6fc0
Copy/paste fix
riedgar-ms May 14, 2024
19a3547
Refactor generator
riedgar-ms May 14, 2024
d29a853
Update class checks
riedgar-ms May 14, 2024
b01b95a
Test fixing
riedgar-ms May 14, 2024
bdd56c7
Want to look at env variables
riedgar-ms May 14, 2024
18f4109
Extra completion model
riedgar-ms May 14, 2024
7859636
Fix name
riedgar-ms May 14, 2024
8b54aa8
Credentials required
riedgar-ms May 14, 2024
57d1c2b
Some more test work
riedgar-ms May 14, 2024
5d6c660
add a tokenizer to the RemoteEngine class to support a default chat_t…
paulbkoch May 14, 2024
d6c0b38
Missing condition
riedgar-ms May 14, 2024
1 change: 1 addition & 0 deletions .github/workflows/ci_tests.yml
@@ -59,6 +59,7 @@ jobs:
python -c "import torch; assert torch.cuda.is_available()"
- name: Test with pytest
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
# Configure endpoints for Azure OpenAI
AZUREAI_CHAT_ENDPOINT: ${{ secrets.AZUREAI_CHAT_ENDPOINT }}
AZUREAI_CHAT_KEY: ${{ secrets.AZUREAI_CHAT_KEY }}
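
The new HF_TOKEN secret presumably lets CI pull gated Hugging Face models for the chat template tests (commit 1dd2f43 above mentions switching a test to an env var for GitHub CI). A sketch of the usual pytest pattern for gating such a test; the test name, body, and model id are illustrative, not taken from the diff:

```python
import os

import pytest

# Skip gracefully when the secret is absent (e.g. on forks without secrets).
requires_hf_token = pytest.mark.skipif(
    "HF_TOKEN" not in os.environ,
    reason="needs a Hugging Face token for gated models",
)

@requires_hf_token
def test_gated_model_chat_template_loads():
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct", token=os.environ["HF_TOKEN"]
    )
    assert tok.chat_template  # the raw template string the cache keys on
```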
4 changes: 3 additions & 1 deletion .gitignore
@@ -22,4 +22,6 @@ guidance/_rust/Cargo.lock

notebooks/**/*.papermill_out.ipynb

.mypy_cache/*
.mypy_cache/*

**/scratch.*
4 changes: 2 additions & 2 deletions README.md
@@ -55,8 +55,8 @@ else:
```python
from guidance import user, assistant

# load a chat model
chat_lm = models.LlamaCppChat(path)
# load a model
chat_lm = models.LlamaCpp(path)

# wrap with chat block contexts
with user():
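
The rest of the README snippet is truncated above; a minimal end-to-end sketch of the updated pattern, where `path` and the prompt are placeholders:

```python
from guidance import models, user, assistant, gen

# One class now covers chat use as well; role formatting comes from the
# model's chat template instead of a separate *Chat subclass.
chat_lm = models.LlamaCpp(path)  # `path` is a placeholder for a local GGUF file

with user():
    lm = chat_lm + "What is the capital of France?"

with assistant():
    lm += gen("answer", max_tokens=20)

print(lm["answer"])
```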
210 changes: 210 additions & 0 deletions guidance/_chat.py
@@ -0,0 +1,210 @@
import warnings
import uuid
import inspect

class ChatTemplate:
"""Contains template for all chat and instruct tuned models."""

def get_role_start(self, role_name, **kwargs):
raise NotImplementedError(
"You need to use a ChatTemplate subclass that overrides the get_role_start method"
)

def get_role_end(self, role_name=None):
raise NotImplementedError(
"You need to use a ChatTemplate subclass that overrides the get_role_start method"
)

class ChatTemplateCache:
def __init__(self):
self._cache = {}

def __getitem__(self, key):
key_compact = key.replace(" ", "")

Collaborator: Minor point: will this collapse multiple consecutive spaces?

Collaborator (author): nah it's a good catch, I'm actually not sure we need to be doing this. I didn't want minor differences in jinja formats (which I believe are whitespace agnostic for parts of them) to cause different mappings in the cache, but maybe there are places where an extra space is actually a meaningful difference?

Collaborator: That was a question I asked before I caught sight of the actual keys you were using....

return self._cache[key_compact]


def __setitem__(self, key, value):
key_compact = key.replace(" ", "")
self._cache[key_compact] = value

def __contains__(self, key):
key_compact = key.replace(" ", "")
return key_compact in self._cache

# Feels weird having to instantiate this, but it's a singleton for all purposes
# TODO [HN]: Add an alias system so we can instantiate with other simple keys (e.g. "llama2" instead of the full template string)
CHAT_TEMPLATE_CACHE = ChatTemplateCache()
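
To make the whitespace handling discussed in the thread above concrete, here is a small illustrative sketch (not part of the diff): every space is stripped from the key, so template strings that differ only in spacing resolve to the same cache entry.

```python
from guidance._chat import ChatMLTemplate, ChatTemplateCache

# Illustrative only: all spaces are removed from the key on get/set/contains,
# so whitespace-only differences between template strings hit the same slot.
cache = ChatTemplateCache()
cache["{{ bos_token }} {{ messages }}"] = ChatMLTemplate

assert "{{bos_token}}{{messages}}" in cache
assert cache["{{ bos_token }}  {{ messages }}"] is ChatMLTemplate
```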

class UnsupportedRoleException(Exception):
def __init__(self, role_name, instance):
self.role_name = role_name
self.instance = instance
super().__init__(self._format_message())

def _format_message(self):
return (f"Role {self.role_name} is not supported by the {self.instance.__class__.__name__} chat template. ")

def load_template_class(chat_template=None):
"""Utility method to find the best chat template.

Order of precedence:
- If it's a chat template class, use it directly
- If it's a string, check the cache of popular model templates
- If it's a string and not in the cache, try to create a class dynamically
- [TODO] If it's a string and can't be created, default to ChatML and raise a warning
- If it's None, default to ChatML and raise a warning
"""
if inspect.isclass(chat_template) and issubclass(chat_template, ChatTemplate):
if chat_template is ChatTemplate:
raise Exception("You can't use the base ChatTemplate class directly. Create or use a subclass instead.")
return chat_template

elif isinstance(chat_template, str):
# First check the cache of popular model types
# TODO: Expand keys of cache to include aliases for popular model types (e.g. "llama2, phi3")
# Can possibly accomplish this with an "aliases" dictionary that maps all aliases to the canonical key in cache
if chat_template in CHAT_TEMPLATE_CACHE:
return CHAT_TEMPLATE_CACHE[chat_template]
# TODO: Add logic here to try to auto-create class dynamically via _template_class_from_string method

# Only warn when a user provided a chat template that we couldn't load
if chat_template is not None:
warnings.warn(f"""Chat template {chat_template} was unable to be loaded directly into guidance.
Defaulting to the ChatML format which may not be optimal for the selected model.
For best results, create and pass in a `guidance.ChatTemplate` subclass for your model.""")

# By default, use the ChatML Template. Warnings to user will happen downstream only if they use chat roles.
return ChatMLTemplate
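
A quick illustration of the resolution order described in the docstring (a sketch, not part of the committed file; it assumes the module has finished loading so the ChatML names exist):

```python
from guidance._chat import ChatMLTemplate, chatml_template, load_template_class

assert load_template_class(ChatMLTemplate) is ChatMLTemplate    # subclass: passed through
assert load_template_class(chatml_template) is ChatMLTemplate   # known string: cache hit
assert load_template_class(None) is ChatMLTemplate              # None: silent ChatML default
assert load_template_class("{% custom %}") is ChatMLTemplate    # unknown string: warns, falls back
```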


def _template_class_from_string(template_str):
"""Utility method to try to create a chat template class from a string."""
# TODO: Try to build this, perhaps based on passing unit tests we create?
pass


# CACHE IMPLEMENTATIONS:

# --------------------------------------------------
# @@@@ ChatML @@@@
# --------------------------------------------------
# Note that all grammarless models will default to this syntax, since we typically send chat formatted messages.
chatml_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
class ChatMLTemplate(ChatTemplate):
template_str = chatml_template

def get_role_start(self, role_name):
return f"<|im_start|>{role_name}\n"

def get_role_end(self, role_name=None):
return "<|im_end|>\n"

CHAT_TEMPLATE_CACHE[chatml_template] = ChatMLTemplate
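
For reference, what the ChatML helpers emit (illustrative only):

```python
t = ChatMLTemplate()
print(t.get_role_start("user") + "Hello" + t.get_role_end())
# <|im_start|>user
# Hello<|im_end|>
```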


# --------------------------------------------------
# @@@@ Llama-2 @@@@
# --------------------------------------------------
# [05/08/24] https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/tokenizer_config.json#L12
llama2_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
class Llama2ChatTemplate(ChatTemplate):
# available_roles = ["system", "user", "assistant"]
template_str = llama2_template

def get_role_start(self, role_name):
if role_name == "system":
return "[INST] <<SYS>>\n"
elif role_name == "user":
return "<s>[INST]"
elif role_name == "assistant":
return " "
else:
raise UnsupportedRoleException(role_name, self)

def get_role_end(self, role_name=None):
if role_name == "system":
return "\n<</SYS>"
elif role_name == "user":
return " [/INST]"
elif role_name == "assistant":
return "</s>"
else:
raise UnsupportedRoleException(role_name, self)

Comment on lines +111 to +133

Collaborator (author): this template is likely an oversimplification of the hf template string, I'll need to debug more and extend this.

Collaborator: Don't forget to add the tests you develop during your debugging.

CHAT_TEMPLATE_CACHE[llama2_template] = Llama2ChatTemplate


# --------------------------------------------------
# @@@@ Llama-3 @@@@
# --------------------------------------------------
# [05/08/24] https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json#L2053
llama3_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
class Llama3ChatTemplate(ChatTemplate):
# available_roles = ["system", "user", "assistant"]
template_str = llama3_template

def get_role_start(self, role_name):
if role_name == "system":
return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
elif role_name == "user":
return "<|start_header_id|>user<|end_header_id>\n\n"
elif role_name == "assistant":
return "<|start_header_id|>assistant<|end_header_id>\n\n"
else:
raise UnsupportedRoleException(role_name, self)

def get_role_end(self, role_name=None):
return "<|eot_id|>"

CHAT_TEMPLATE_CACHE[llama3_template] = Llama3ChatTemplate

# --------------------------------------------------
# @@@@ Phi-3 @@@@
# --------------------------------------------------
# [05/08/24] https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/tokenizer_config.json#L119
phi3_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
class Phi3ChatTemplate(ChatTemplate):
# available_roles = ["user", "assistant"]
template_str = phi3_template

def get_role_start(self, role_name):
if role_name == "user":
return "<|user|>"
elif role_name == "assistant":
return "<|assistant|>"
else:
raise UnsupportedRoleException(role_name, self)

def get_role_end(self, role_name=None):
return "<|end|>"

CHAT_TEMPLATE_CACHE[phi3_template] = Phi3ChatTemplate


# --------------------------------------------------
# @@@@ Mistral-7B-Instruct-v0.2 @@@@
# --------------------------------------------------
# [05/08/24] https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/tokenizer_config.json#L42
mistral_7b_instruct_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
class Mistral7BInstructChatTemplate(ChatTemplate):
# available_roles = ["user", "assistant"]
template_str = mistral_7b_instruct_template

def get_role_start(self, role_name):
if role_name == "user":
return "[INST] "
elif role_name == "assistant":
return ""
else:
raise UnsupportedRoleException(role_name, self)

def get_role_end(self, role_name=None):
if role_name == "user":
return " [/INST]"
elif role_name == "assistant":
return "</s>"
else:
raise UnsupportedRoleException(role_name, self)

CHAT_TEMPLATE_CACHE[mistral_7b_instruct_template] = Mistral7BInstructChatTemplate
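
The cache above only ships with a handful of popular formats; anything else currently falls back to ChatML with a warning. A hedged sketch of the extension path that warning points at: subclass ChatTemplate for your model and register it under the model's raw template string, or pass the class in directly. The class name and role markers below are hypothetical.

```python
from guidance._chat import CHAT_TEMPLATE_CACHE, ChatTemplate, UnsupportedRoleException

# Hypothetical template for a model whose format is not in the cache.
class MyModelChatTemplate(ChatTemplate):
    template_str = "{% for message in messages %}...{% endfor %}"  # the model's raw jinja template

    def get_role_start(self, role_name):
        if role_name == "user":
            return "<turn role=user>\n"
        elif role_name == "assistant":
            return "<turn role=assistant>\n"
        raise UnsupportedRoleException(role_name, self)

    def get_role_end(self, role_name=None):
        return "</turn>\n"

# Registering under the raw template string lets load_template_class resolve it
# automatically when that exact string comes from a tokenizer config; the class
# can also be passed to guidance directly.
CHAT_TEMPLATE_CACHE[MyModelChatTemplate.template_str] = MyModelChatTemplate
```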
28 changes: 19 additions & 9 deletions guidance/library/_role.py
@@ -7,14 +7,10 @@
span_start = "<||_html:<span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'>_||>"
span_end = "<||_html:</span>_||>"


@guidance
def role_opener(lm, role_name, **kwargs):
indent = getattr(lm, "indent_roles", True)
if not hasattr(lm, "get_role_start"):
raise Exception(
f"You need to use a chat model in order the use role blocks like `with {role_name}():`! Perhaps you meant to use the {type(lm).__name__}Chat class?"
)


# Block start container (centers elements)
if indent:
@@ -25,8 +21,17 @@ def role_opener(lm, role_name, **kwargs):
lm += nodisp_start
else:
lm += span_start

lm += lm.get_role_start(role_name, **kwargs)

# TODO [HN]: Temporary change while I instrument chat_template in transformers only.

Collaborator: Is this still the case (and below)?

# Eventually have all models use chat_template.
if hasattr(lm, "get_role_start"):
lm += lm.get_role_start(role_name, **kwargs)
elif hasattr(lm, "chat_template"):
lm += lm.chat_template.get_role_start(role_name)
else:
raise Exception(
f"You need to use a chat model in order the use role blocks like `with {role_name}():`! Perhaps you meant to use the {type(lm).__name__}Chat class?"
)

# End of either debug or HTML no disp block
if indent:
@@ -46,7 +51,12 @@ def role_closer(lm, role_name, **kwargs):
else:
lm += span_start

lm += lm.get_role_end(role_name)
# TODO [HN]: Temporary change while I instrument chat_template in transformers only.
# Eventually have all models use chat_template.
if hasattr(lm, "get_role_end"):
lm += lm.get_role_end(role_name)
elif hasattr(lm, "chat_template"):
lm += lm.chat_template.get_role_end(role_name)

# End of either debug or HTML no disp block
if indent:
@@ -60,7 +70,7 @@ def role_closer(lm, role_name, **kwargs):

return lm


# TODO HN: Add a docstring to better describe arbitrary role functions
def role(role_name, text=None, **kwargs):
if text is None:
return block(
11 changes: 4 additions & 7 deletions guidance/models/__init__.py
@@ -1,8 +1,8 @@
from ._model import Model, Instruct, Chat

# local models
from .transformers._transformers import Transformers, TransformersChat
from .llama_cpp import LlamaCpp, LlamaCppChat, MistralInstruct, MistralChat
from .transformers._transformers import Transformers
from .llama_cpp import LlamaCpp
from ._mock import Mock, MockChat

# grammarless models (we can't do constrained decoding for them)
@@ -15,15 +15,12 @@
)
from ._azure_openai import (
AzureOpenAI,
AzureOpenAIChat,
AzureOpenAICompletion,
AzureOpenAIInstruct,
)
from ._azureai_studio import AzureAIStudioChat
from ._openai import OpenAI, OpenAIChat, OpenAIInstruct, OpenAICompletion
from ._openai import OpenAI
from ._lite_llm import LiteLLM, LiteLLMChat, LiteLLMInstruct, LiteLLMCompletion
from ._cohere import Cohere, CohereCompletion, CohereInstruct
from ._anthropic import Anthropic, AnthropicChat
from ._anthropic import Anthropic
from ._googleai import GoogleAI, GoogleAIChat
from ._togetherai import (
TogetherAI,
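
For downstream users, the practical effect of these import changes is that the per-flavor classes disappear; a before/after sketch (the model id is illustrative):

```python
from guidance import models

# Before this PR:
# lm = models.TransformersChat("microsoft/Phi-3-mini-4k-instruct")

# After this PR there is one class per backend, and chat formatting is driven
# by the model's chat template:
lm = models.Transformers("microsoft/Phi-3-mini-4k-instruct")
```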