Skip to content

Commit

Permalink
add azure openai (#26)
Browse files Browse the repository at this point in the history
* add azure openai

* discribe missing env

* move azure condition to top
  • Loading branch information
ruanwz authored Jul 9, 2024
1 parent d6befc6 commit 8bb828a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
9 changes: 7 additions & 2 deletions textgrad/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,15 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
if engine_name in __ENGINE_NAME_SHORTCUTS__:
engine_name = __ENGINE_NAME_SHORTCUTS__[engine_name]

if "seed" in kwargs and engine_name not in ["gpt-4", "gpt-3.5"]:
if "seed" in kwargs and "gpt-4" not in engine_name and "gpt-3.5" not in engine_name and "gpt-35" not in engine_name:
raise ValueError(f"Seed is currently supported only for OpenAI engines, not {engine_name}")

if (("gpt-4" in engine_name) or ("gpt-3.5" in engine_name)):
if engine_name.startswith("azure"):
from .openai import AzureChatOpenAI
# remove engine_name "azure-" prefix
engine_name = engine_name[6:]
return AzureChatOpenAI(model_string=engine_name, **kwargs)
elif (("gpt-4" in engine_name) or ("gpt-3.5" in engine_name)):
from .openai import ChatOpenAI
return ChatOpenAI(model_string=engine_name, is_multimodal=_check_if_multimodal(engine_name), **kwargs)
elif "claude" in engine_name:
Expand Down
43 changes: 42 additions & 1 deletion textgrad/engine/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
try:
from openai import OpenAI
from openai import OpenAI, AzureOpenAI
except ImportError:
raise ImportError("If you'd like to use OpenAI models, please install the openai package by running `pip install openai`, and add 'OPENAI_API_KEY' to your environment variables.")

Expand Down Expand Up @@ -138,3 +138,44 @@ def _generate_from_multiple_input(
response_text = response.choices[0].message.content
self._save_cache(cache_key, response_text)
return response_text

class AzureChatOpenAI(ChatOpenAI):
def __init__(
self,
model_string="gpt-35-turbo",
system_prompt=ChatOpenAI.DEFAULT_SYSTEM_PROMPT,
**kwargs):
"""
Initializes an interface for interacting with Azure's OpenAI models.
This class extends the ChatOpenAI class to use Azure's OpenAI API instead of OpenAI's API. It sets up the necessary client with the appropriate API version, API key, and endpoint from environment variables.
:param model_string: The model identifier for Azure OpenAI. Defaults to 'gpt-3.5-turbo'.
:param system_prompt: The default system prompt to use when generating responses. Defaults to ChatOpenAI's default system prompt.
:param kwargs: Additional keyword arguments to pass to the ChatOpenAI constructor.
Environment variables:
- AZURE_OPENAI_API_KEY: The API key for authenticating with Azure OpenAI.
- AZURE_OPENAI_API_BASE: The base URL for the Azure OpenAI API.
- AZURE_OPENAI_API_VERSION: The API version to use. Defaults to '2023-07-01-preview' if not set.
Raises:
ValueError: If the AZURE_OPENAI_API_KEY environment variable is not set.
"""
root = platformdirs.user_cache_dir("textgrad")
cache_path = os.path.join(root, f"cache_azure_{model_string}.db") # Changed cache path to differentiate from OpenAI cache

super().__init__(cache_path=cache_path, system_prompt=system_prompt, **kwargs)

self.system_prompt = system_prompt
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2023-07-01-preview")
if os.getenv("AZURE_OPENAI_API_KEY") is None:
raise ValueError("Please set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_BASE, and AZURE_OPENAI_API_VERSION environment variables if you'd like to use Azure OpenAI models.")

self.client = AzureOpenAI(
api_version=api_version,
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"),
azure_deployment=model_string,
)
self.model_string = model_string

0 comments on commit 8bb828a

Please sign in to comment.