From 8bb828afb5088f7048af0466613cc240b21094f5 Mon Sep 17 00:00:00 2001
From: David Ruan
Date: Wed, 10 Jul 2024 07:42:19 +0800
Subject: [PATCH] add azure openai (#26)

* add azure openai

* describe missing env

* move azure condition to top
---
 textgrad/engine/__init__.py |  9 ++++++--
 textgrad/engine/openai.py   | 43 ++++++++++++++++++++++++++++++++++++-
 2 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/textgrad/engine/__init__.py b/textgrad/engine/__init__.py
index 0625168..25b3ebd 100644
--- a/textgrad/engine/__init__.py
+++ b/textgrad/engine/__init__.py
@@ -30,10 +30,15 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
     if engine_name in __ENGINE_NAME_SHORTCUTS__:
         engine_name = __ENGINE_NAME_SHORTCUTS__[engine_name]
 
-    if "seed" in kwargs and engine_name not in ["gpt-4", "gpt-3.5"]:
+    if "seed" in kwargs and "gpt-4" not in engine_name and "gpt-3.5" not in engine_name and "gpt-35" not in engine_name:
         raise ValueError(f"Seed is currently supported only for OpenAI engines, not {engine_name}")
 
-    if (("gpt-4" in engine_name) or ("gpt-3.5" in engine_name)):
+    if engine_name.startswith("azure"):
+        from .openai import AzureChatOpenAI
+        # remove the "azure-" prefix from engine_name
+        engine_name = engine_name[6:]
+        return AzureChatOpenAI(model_string=engine_name, **kwargs)
+    elif (("gpt-4" in engine_name) or ("gpt-3.5" in engine_name)):
         from .openai import ChatOpenAI
         return ChatOpenAI(model_string=engine_name, is_multimodal=_check_if_multimodal(engine_name), **kwargs)
     elif "claude" in engine_name:
diff --git a/textgrad/engine/openai.py b/textgrad/engine/openai.py
index a6d10ec..0a922b4 100644
--- a/textgrad/engine/openai.py
+++ b/textgrad/engine/openai.py
@@ -1,5 +1,5 @@
 try:
-    from openai import OpenAI
+    from openai import OpenAI, AzureOpenAI
 except ImportError:
     raise ImportError("If you'd like to use OpenAI models, please install the openai package by running `pip install openai`, and add 'OPENAI_API_KEY' to your environment variables.")
 
@@ -138,3 +138,44 @@ def _generate_from_multiple_input(
         response_text = response.choices[0].message.content
         self._save_cache(cache_key, response_text)
         return response_text
+
+class AzureChatOpenAI(ChatOpenAI):
+    def __init__(
+        self,
+        model_string="gpt-35-turbo",
+        system_prompt=ChatOpenAI.DEFAULT_SYSTEM_PROMPT,
+        **kwargs):
+        """
+        Initializes an interface for interacting with Azure's OpenAI models.
+
+        This class extends the ChatOpenAI class to use Azure's OpenAI API instead of OpenAI's API. It sets up the necessary client with the appropriate API version, API key, and endpoint from environment variables.
+
+        :param model_string: The model identifier for Azure OpenAI. Defaults to 'gpt-35-turbo'.
+        :param system_prompt: The default system prompt to use when generating responses. Defaults to ChatOpenAI's default system prompt.
+        :param kwargs: Additional keyword arguments to pass to the ChatOpenAI constructor.
+
+        Environment variables:
+        - AZURE_OPENAI_API_KEY: The API key for authenticating with Azure OpenAI.
+        - AZURE_OPENAI_API_BASE: The base URL for the Azure OpenAI API.
+        - AZURE_OPENAI_API_VERSION: The API version to use. Defaults to '2023-07-01-preview' if not set.
+
+        Raises:
+            ValueError: If the AZURE_OPENAI_API_KEY environment variable is not set.
+ """ + root = platformdirs.user_cache_dir("textgrad") + cache_path = os.path.join(root, f"cache_azure_{model_string}.db") # Changed cache path to differentiate from OpenAI cache + + super().__init__(cache_path=cache_path, system_prompt=system_prompt, **kwargs) + + self.system_prompt = system_prompt + api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2023-07-01-preview") + if os.getenv("AZURE_OPENAI_API_KEY") is None: + raise ValueError("Please set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_BASE, and AZURE_OPENAI_API_VERSION environment variables if you'd like to use Azure OpenAI models.") + + self.client = AzureOpenAI( + api_version=api_version, + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"), + azure_deployment=model_string, + ) + self.model_string = model_string