[Feature] Support for Azure AI Studio #779

Merged

Changes from 7 commits

Commits (34)
386cdd3
Starting to think about what we need for AzureAI Studio
riedgar-ms Apr 26, 2024
176201c
Getting to the initially desired failure
riedgar-ms Apr 26, 2024
beac0cc
Very rough draft....
riedgar-ms Apr 26, 2024
3d90baa
Inching along
riedgar-ms Apr 26, 2024
32bc793
Trying to get things working :-/
riedgar-ms Apr 26, 2024
7840cfd
Didn't mean to check that in
riedgar-ms Apr 26, 2024
04e45c7
Erroneous addition
riedgar-ms Apr 26, 2024
bcc241a
Merge branch 'main' into riedgar-ms/azure-ai-studio-support-01
riedgar-ms Apr 29, 2024
25ecccf
Switch to requests
riedgar-ms Apr 29, 2024
1265346
Make sure that cache is unique to endpoint/deployment
riedgar-ms Apr 29, 2024
f348880
Starting to test mistral too.... not fully working yet
riedgar-ms Apr 29, 2024
0fc4727
Get the Mistral test working
riedgar-ms Apr 29, 2024
2be7f58
Add LLama3
riedgar-ms Apr 29, 2024
3bcb48e
Expand the endpoint configuration
riedgar-ms Apr 30, 2024
559b341
Merge remote-tracking branch 'upstream/main' into riedgar-ms/azure-ai…
riedgar-ms May 1, 2024
9e46101
Add option to clear cache on instantiating model
riedgar-ms May 1, 2024
0a7cc81
Some more experimenting
riedgar-ms May 1, 2024
1584d9f
Want some parallel Azure OpenAI tests
riedgar-ms May 1, 2024
60b23c8
Copy/paste error
riedgar-ms May 1, 2024
10fc9ba
Change test to passing
riedgar-ms May 1, 2024
b68e9d7
Expand Azure AI Studio testing
riedgar-ms May 1, 2024
9a4c1a8
Refactor tests
riedgar-ms May 1, 2024
c0769b6
Refactor tests
riedgar-ms May 1, 2024
2a60c94
Merge branch 'main' into riedgar-ms/azure-ai-studio-support-01
riedgar-ms May 1, 2024
9c755d6
Start doc writing
riedgar-ms May 2, 2024
21ee13f
Add some basic docs
riedgar-ms May 2, 2024
cdf679c
Use the new endpoint
riedgar-ms May 3, 2024
4281d7f
Handle optional import
riedgar-ms May 3, 2024
7bf3d07
OpenAI guard mk II
riedgar-ms May 3, 2024
c973076
Merge remote-tracking branch 'upstream/main' into riedgar-ms/azure-ai…
riedgar-ms May 3, 2024
50847c5
Small fixes for mypy
riedgar-ms May 3, 2024
3d73c42
One suppression....
riedgar-ms May 3, 2024
64ec232
More mypy fixing
riedgar-ms May 3, 2024
2327c8a
Merge remote-tracking branch 'upstream/main' into riedgar-ms/azure-ai…
riedgar-ms May 6, 2024
1 change: 1 addition & 0 deletions guidance/models/__init__.py
@@ -19,6 +19,7 @@
AzureOpenAICompletion,
AzureOpenAIInstruct,
)
from ._azureai_studio import AzureAIStudioChat
from ._openai import OpenAI, OpenAIChat, OpenAIInstruct, OpenAICompletion
from ._lite_llm import LiteLLM, LiteLLMChat, LiteLLMInstruct, LiteLLMCompletion
from ._cohere import Cohere, CohereCompletion, CohereInstruct
161 changes: 161 additions & 0 deletions guidance/models/_azureai_studio.py
@@ -0,0 +1,161 @@
import hashlib
import json
import pathlib
import urllib.request

import diskcache as dc
import platformdirs

from ._model import Chat
from ._grammarless import GrammarlessEngine, Grammarless


class AzureAIStudioChatEngine(GrammarlessEngine):
def __init__(
self,
*,
tokenizer,
Collaborator Author

I never try setting the tokeniser, and it appears that it eventually defaults to GPT2. I don't quite see why a remote model like this would even need it.

Collaborator Author

OK, so theoretically for token healing. However, I have a feeling that trying to figure out what tokeniser to use will be an exercise in fragility.

max_streaming_tokens: int,
timeout: float,
compute_log_probs: bool,
azureai_studio_endpoint: str,
azureai_model_deployment: str,
azureai_studio_key: str,
):
self._endpoint = azureai_studio_endpoint
self._deployment = azureai_model_deployment
self._api_key = azureai_studio_key

path = (
pathlib.Path(platformdirs.user_cache_dir("guidance"))
/ "azureaistudio.tokens"
)
self.cache = dc.Cache(path)

super().__init__(tokenizer, max_streaming_tokens, timeout, compute_log_probs)

def _hash_prompt(self, prompt):
# Copied from OpenAIChatEngine
return hashlib.sha256(f"{prompt}".encode()).hexdigest()
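Note that as written the key depends only on the prompt, so different endpoints or deployments would share cache entries; a later commit in this PR ("Make sure that cache is unique to endpoint/deployment") tightens this. A rough sketch of one way to do it — the free function below is hypothetical, not part of this diff:

```python
import hashlib


# Illustrative sketch only: fold the endpoint and deployment into the cache key
# so that responses from two different deployments can never collide.
def hash_prompt_for_cache(endpoint: str, deployment: str, prompt) -> str:
    key_material = f"{endpoint}|{deployment}|{prompt}"
    return hashlib.sha256(key_material.encode()).hexdigest()
```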

def _generator(self, prompt, temperature: float):
# Initial parts of this straight up copied from OpenAIChatEngine

# The next loop (or one like it) appears in several places,
# and quite possibly belongs in a library function or superclass
# That said, I'm not completely sure that there aren't subtle
# differences between the various versions

Collaborator Author

Thoughts on this?

This is a straight-up copy of what is in _openai.py, but there are almost-but-not-quite the same versions elsewhere. Chopping up the internal 'guidance state string' into individual chat messages feels like a task that all chat models will need - and hence should be centrally provided. Individual engines can then turn that format into whatever they precisely need.
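A minimal sketch of the kind of shared helper being suggested here; the name parse_chat_messages and its signature are illustrative, not part of this PR:

```python
# Illustrative sketch only: the shared helper suggested above, so each
# Grammarless engine does not need its own copy of this role-tag loop.
def parse_chat_messages(prompt: bytes) -> tuple[list[dict], int]:
    """Split a guidance prompt into chat messages.

    Returns the parsed messages plus the position up to which the prompt
    was consumed (the final, unterminated assistant block is left open).
    """
    role_end = b"<|im_end|>"
    role_starts = {
        "system": b"<|im_start|>system\n",
        "user": b"<|im_start|>user\n",
        "assistant": b"<|im_start|>assistant\n",
    }
    messages = []
    pos = 0
    found = True
    while found:
        found = False
        for role_name, start_bytes in role_starts.items():
            if prompt[pos:].startswith(start_bytes):
                pos += len(start_bytes)
                end_pos = prompt[pos:].find(role_end)
                if end_pos < 0:
                    # The last block must be the open assistant block being generated into
                    assert role_name == "assistant", "Last role before gen must be assistant"
                    break
                messages.append(
                    {"role": role_name, "content": prompt[pos : pos + end_pos].decode("utf8")}
                )
                pos += end_pos + len(role_end)
                found = True
                break
    return messages, pos
```

Each engine could then translate the returned messages into whatever request format its endpoint expects.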

# find the role tags
pos = 0
role_end = b"<|im_end|>"
messages = []
found = True
while found:

# find the role text blocks
found = False
for role_name, start_bytes in (
("system", b"<|im_start|>system\n"),
("user", b"<|im_start|>user\n"),
("assistant", b"<|im_start|>assistant\n"),
Comment on lines +84 to +86
Collaborator

Do AzureAI models uniformly use the same role tags across their models? I don't think we can hard code a check for these start_bytes in this class

Collaborator Author

These aren't coming from the model, surely? They're coming from guidance as it sends its current prompt back into this class, to be formatted and forwarded on to the model. Or.... have I managed to miss another action-at-a-distance part of guidance?

):
if prompt[pos:].startswith(start_bytes):
pos += len(start_bytes)
end_pos = prompt[pos:].find(role_end)
if end_pos < 0:
assert (
role_name == "assistant"
), "Bad chat format! Last role before gen needs to be assistant!"
break
btext = prompt[pos : pos + end_pos]
pos += end_pos + len(role_end)
messages.append(
{"role": role_name, "content": btext.decode("utf8")}
)
found = True
break

# Add nice exception if no role tags were used in the prompt.
# TODO: Move this somewhere more general for all chat models?
if messages == []:
raise ValueError(
f"The model is a Chat-based model and requires role tags in the prompt! \
Make sure you are using guidance context managers like `with system():`, `with user():` and `with assistant():` \
to appropriately format your guidance program for this type of model."
)

# Update shared data state
self._reset_shared_data(prompt[:pos], temperature)

# Use cache only when temperature is 0
if temperature == 0:
Collaborator Author

This caching logic is a bit of a concern - at least for models where T=0 doesn't actually get determinism. And in general, it means that a bunch of our tests might not quite be doing what we think they're doing, because they may just be hitting the cache.

Collaborator Author
@riedgar-ms riedgar-ms Apr 26, 2024

The more I think about it, the less I like the idea of a disk-based cache. In some ways it's worse on the OpenAI side, where both AzureOpenAI and OpenAI will wind up sharing the same cache.

How much speed up does it really give, compared to the Heisenbug potential it represents?

Collaborator

I think most models have more reliable temp=0 determinism now, but agree that perhaps sharing between AzureOpenAI and OpenAI is problematic (though there shouldn't be differences between the two APIs in theory?). I do think caching is a nice feature to have in general, as production workflows often have shared inputs coming in that save time and money to reuse.

Collaborator Author

I have added a clear_cache argument to the constructor. I think I should add this to the OpenAI side as well, in a separate PR.
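A minimal sketch of what that option could look like; the helper name below is hypothetical, and the actual clear_cache argument lands in a later commit of this PR rather than in this diff:

```python
import pathlib

import diskcache as dc
import platformdirs


# Hypothetical helper illustrating the clear_cache idea discussed above.
def _make_cache(clear_cache: bool = False) -> dc.Cache:
    path = (
        pathlib.Path(platformdirs.user_cache_dir("guidance")) / "azureaistudio.tokens"
    )
    cache = dc.Cache(path)
    if clear_cache:
        # Start from an empty cache so tests cannot silently hit stale entries
        cache.clear()
    return cache
```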

cache_key = self._hash_prompt(prompt)

# Check if the result is already in the cache
if cache_key in self.cache:
for chunk in self.cache[cache_key]:
yield chunk
return

# Now switch to the example code from AzureAI Studio
# Might want to rewrite this to the requests package
riedgar-ms marked this conversation as resolved.
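A rough sketch of the same call using the requests package; the helper name is illustrative, and the actual change happens in the "Switch to requests" commit later in this PR:

```python
import requests


# Illustrative only: POST the payload with requests instead of urllib.request.
def _call_endpoint(endpoint: str, api_key: str, deployment: str, payload: dict) -> dict:
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
        "azureml-model-deployment": deployment,
    }
    response = requests.post(endpoint, json=payload, headers=headers, timeout=30)
    response.raise_for_status()  # surface HTTP errors instead of failing silently
    return response.json()
```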

# Prepare for the API call (this might be model specific....)
riedgar-ms marked this conversation as resolved.
parameters = dict(temperature=temperature)
payload = dict(input_data=dict(input_string=messages, parameters=parameters))
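For reference, the request body this builds looks roughly like the following (the messages shown are illustrative, not taken from the PR):

```python
# Illustrative example of the payload shape produced above (values are made up):
example_payload = {
    "input_data": {
        "input_string": [
            {"role": "system", "content": "You are a math wiz."},
            {"role": "user", "content": "What is 1 + 1?"},
        ],
        "parameters": {"temperature": 0.0},
    }
}
```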

headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + self._api_key),
"azureml-model-deployment": self._deployment,
}

body = str.encode(json.dumps(payload))

req = urllib.request.Request(self._endpoint, body, headers)

response = urllib.request.urlopen(req)
result = json.loads(response.read())
Collaborator Author

The logic around result is another thing which might be model specific.


# Now back to OpenAIChatEngine, with slight modifications since
# this isn't a streaming API
riedgar-ms marked this conversation as resolved.
if temperature == 0:
cached_results = []

encoded_chunk = result["output"].encode("utf8")

yield encoded_chunk

if temperature == 0:
cached_results.append(encoded_chunk)

# Cache the results after the generator is exhausted
if temperature == 0:
self.cache[cache_key] = cached_results


class AzureAIStudioChat(Grammarless, Chat):
def __init__(
self,
azureai_studio_endpoint: str,
azureai_studio_deployment: str,
azureai_studio_key: str,
tokenizer=None,
echo: bool = True,
max_streaming_tokens: int = 1000,
timeout: float = 0.5,
compute_log_probs: bool = False,
):
super().__init__(
AzureAIStudioChatEngine(
azureai_studio_endpoint=azureai_studio_endpoint,
azureai_model_deployment=azureai_studio_deployment,
azureai_studio_key=azureai_studio_key,
tokenizer=tokenizer,
max_streaming_tokens=max_streaming_tokens,
timeout=timeout,
compute_log_probs=compute_log_probs,
),
echo=echo,
)
39 changes: 17 additions & 22 deletions tests/models/test_azureai_openai.py
@@ -7,23 +7,18 @@

from guidance import assistant, gen, models, system, user

from ..utils import env_or_fail

# Everything in here needs credentials to work
# Mark is configured in pyproject.toml
pytestmark = pytest.mark.needs_credentials


def _env_or_fail(var_name: str) -> str:
env_value = os.getenv(var_name, None)

assert env_value is not None, f"Env '{var_name}' not found"

return env_value


def test_azureai_openai_chat_smoke(rate_limiter):
azureai_endpoint = _env_or_fail("AZUREAI_CHAT_ENDPOINT")
azureai_key = _env_or_fail("AZUREAI_CHAT_KEY")
model = _env_or_fail("AZUREAI_CHAT_MODEL")
azureai_endpoint = env_or_fail("AZUREAI_CHAT_ENDPOINT")
azureai_key = env_or_fail("AZUREAI_CHAT_KEY")
model = env_or_fail("AZUREAI_CHAT_MODEL")

lm = models.AzureOpenAI(
model=model, azure_endpoint=azureai_endpoint, api_key=azureai_key
@@ -45,9 +40,9 @@ def test_azureai_openai_chat_smoke(rate_limiter):


def test_azureai_openai_chat_alt_args(rate_limiter):
azureai_endpoint = _env_or_fail("AZUREAI_CHAT_ENDPOINT")
azureai_key = _env_or_fail("AZUREAI_CHAT_KEY")
model = _env_or_fail("AZUREAI_CHAT_MODEL")
azureai_endpoint = env_or_fail("AZUREAI_CHAT_ENDPOINT")
azureai_key = env_or_fail("AZUREAI_CHAT_KEY")
model = env_or_fail("AZUREAI_CHAT_MODEL")

parsed_url = urlparse(azureai_endpoint)
parsed_query = parse_qs(parsed_url.query)
@@ -78,9 +73,9 @@ def test_azureai_openai_chat_alt_args(rate_limiter):


def test_azureai_openai_completion_smoke(rate_limiter):
azureai_endpoint = _env_or_fail("AZUREAI_COMPLETION_ENDPOINT")
azureai_key = _env_or_fail("AZUREAI_COMPLETION_KEY")
model = _env_or_fail("AZUREAI_COMPLETION_MODEL")
azureai_endpoint = env_or_fail("AZUREAI_COMPLETION_ENDPOINT")
azureai_key = env_or_fail("AZUREAI_COMPLETION_KEY")
model = env_or_fail("AZUREAI_COMPLETION_MODEL")

lm = models.AzureOpenAI(
model=model, azure_endpoint=azureai_endpoint, api_key=azureai_key
@@ -93,9 +88,9 @@ def test_azureai_openai_completion_smoke(rate_limiter):


def test_azureai_openai_completion_alt_args(rate_limiter):
azureai_endpoint = _env_or_fail("AZUREAI_COMPLETION_ENDPOINT")
azureai_key = _env_or_fail("AZUREAI_COMPLETION_KEY")
model = _env_or_fail("AZUREAI_COMPLETION_MODEL")
azureai_endpoint = env_or_fail("AZUREAI_COMPLETION_ENDPOINT")
azureai_key = env_or_fail("AZUREAI_COMPLETION_KEY")
model = env_or_fail("AZUREAI_COMPLETION_MODEL")

parsed_url = urlparse(azureai_endpoint)
parsed_query = parse_qs(parsed_url.query)
@@ -118,9 +113,9 @@ def test_azureai_openai_completion_alt_args(rate_limiter):


def test_azureai_openai_chat_loop(rate_limiter):
azureai_endpoint = _env_or_fail("AZUREAI_CHAT_ENDPOINT")
azureai_key = _env_or_fail("AZUREAI_CHAT_KEY")
model = _env_or_fail("AZUREAI_CHAT_MODEL")
azureai_endpoint = env_or_fail("AZUREAI_CHAT_ENDPOINT")
azureai_key = env_or_fail("AZUREAI_CHAT_KEY")
model = env_or_fail("AZUREAI_CHAT_MODEL")

lm = models.AzureOpenAI(
model=model, azure_endpoint=azureai_endpoint, api_key=azureai_key
37 changes: 37 additions & 0 deletions tests/models/test_azureai_studio.py
@@ -0,0 +1,37 @@
import pytest

from guidance import assistant, gen, models, system, user

from ..utils import env_or_fail

# Everything in here needs credentials to work
# Mark is configured in pyproject.toml
pytestmark = pytest.mark.needs_credentials


def test_azureai_openai_chat_smoke(rate_limiter):
azureai_studio_endpoint = env_or_fail("AZURE_AI_STUDIO_ENDPOINT")
azureai_studio_deployment = env_or_fail("AZURE_AI_STUDIO_DEPLOYMENT")
azureai_studio_key = env_or_fail("AZURE_AI_STUDIO_KEY")

lm = models.AzureAIStudioChat(
azureai_studio_endpoint=azureai_studio_endpoint,
azureai_studio_deployment=azureai_studio_deployment,
azureai_studio_key=azureai_studio_key,
)
assert isinstance(lm, models.AzureAIStudioChat)

with system():
lm += "You are a math wiz."

with user():
lm += "What is 1 + 1?"

with assistant():
lm += gen(max_tokens=10, name="text")
lm += "Pick a number: "

print(str(lm))
assert len(lm["text"]) > 0
riedgar-ms marked this conversation as resolved.
6 changes: 6 additions & 0 deletions tests/utils.py
@@ -8,6 +8,12 @@

opanai_model_cache = {}

def env_or_fail(var_name: str) -> str:
env_value = os.getenv(var_name, None)

assert env_value is not None, f"Env '{var_name}' not found"

return env_value

def get_model(model_name, caching=False, **kwargs):
"""Get an LLM by name."""