Skip to content

Commit

Permalink
Switched to AzureOpenAI for api_type=="azure" (microsoft#1232)
Browse files Browse the repository at this point in the history
* Switched to AzureOpenAI for api_type=="azure"

* Setting AzureOpenAI to empty object if no `openai`

* extra_ and openai_ kwargs

* test_client, support for Azure and "gpt-35-turbo-instruct"

* instruct/azure model in test_client_stream

* generalize aoai support (#1)

* generalize aoai support

* Null check, fixing tests

* cleanup test

---------

Co-authored-by: Maxim Saplin <[email protected]>

* Returning back model names for instruct

* process model in create

* None check

---------

Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
maxim-saplin and sonichi authored Jan 17, 2024
1 parent 39182cc commit 00dbcb2
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 60 deletions.
74 changes: 31 additions & 43 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from autogen.oai import completion

from autogen.oai.openai_utils import get_key, OAI_PRICE1K
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, get_key, OAI_PRICE1K
from autogen.token_count_utils import count_token
from autogen._pydantic import model_dump

Expand All @@ -21,9 +21,10 @@
except ImportError:
ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
OpenAI = object
AzureOpenAI = object
else:
# raises exception if openai>=1 is installed and something is wrong with imports
from openai import OpenAI, APIError, __version__ as OPENAIVERSION
from openai import OpenAI, AzureOpenAI, APIError, __version__ as OPENAIVERSION
from openai.resources import Completions
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
Expand Down Expand Up @@ -52,8 +53,18 @@ class OpenAIWrapper:
"""A wrapper class for openai client."""

cache_path_root: str = ".cache"
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"}
extra_kwargs = {
"cache_seed",
"filter_func",
"allow_format_str_template",
"context",
"api_version",
"api_type",
"tags",
}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs)
openai_kwargs = openai_kwargs | aopenai_kwargs
total_usage_summary: Optional[Dict[str, Any]] = None
actual_usage_summary: Optional[Dict[str, Any]] = None

Expand Down Expand Up @@ -105,46 +116,10 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base
self._clients = [self._client(extra_kwargs, openai_config)]
self._config_list = [extra_kwargs]

def _process_for_azure(
self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default"
) -> None:
# deal with api_version
query_segment = f"{segment}_query"
headers_segment = f"{segment}_headers"
api_version = extra_kwargs.get("api_version")
if api_version is not None and query_segment not in config:
config[query_segment] = {"api-version": api_version}
if segment == "default":
# remove the api_version from extra_kwargs
extra_kwargs.pop("api_version")
if segment == "extra":
return
# deal with api_type
api_type = extra_kwargs.get("api_type")
if api_type is not None and api_type.startswith("azure") and headers_segment not in config:
api_key = config.get("api_key", os.environ.get("AZURE_OPENAI_API_KEY"))
config[headers_segment] = {"api-key": api_key}
# remove the api_type from extra_kwargs
extra_kwargs.pop("api_type")
# deal with model
model = extra_kwargs.get("model")
if model is None:
return
if "gpt-3.5" in model:
# hack for azure gpt-3.5
extra_kwargs["model"] = model = model.replace("gpt-3.5", "gpt-35")
base_url = config.get("base_url")
if base_url is None:
raise ValueError("to use azure openai api, base_url must be specified.")
suffix = f"/openai/deployments/{model}"
if not base_url.endswith(suffix):
config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix

def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Separate the config into openai_config and extra_kwargs."""
openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
self._process_for_azure(openai_config, extra_kwargs)
return openai_config, extra_kwargs

def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
Expand All @@ -156,10 +131,22 @@ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any
def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI:
"""Create a client with the given config to override openai_config,
after removing extra kwargs.
For Azure models/deployment names there's a convenience modification of model removing dots in
the it's value (Azure deploment names can't have dots). I.e. if you have Azure deployment name
"gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
from the name and create a client that connects to "gpt-35-turbo" Azure deployment.
"""
openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
self._process_for_azure(openai_config, config)
client = OpenAI(**openai_config)
api_type = config.get("api_type")
if api_type is not None and api_type.startswith("azure"):
openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model"))
if openai_config["azure_deployment"] is not None:
openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))
client = AzureOpenAI(**openai_config)
else:
client = OpenAI(**openai_config)
return client

@classmethod
Expand Down Expand Up @@ -242,8 +229,9 @@ def yes_or_no_filter(context, response):
full_config = {**config, **self._config_list[i]}
# separate the config into create_config and extra_kwargs
create_config, extra_kwargs = self._separate_create_config(full_config)
# process for azure
self._process_for_azure(create_config, extra_kwargs, "extra")
api_type = extra_kwargs.get("api_type")
if api_type and api_type.startswith("azure") and "model" in create_config:
create_config["model"] = create_config["model"].replace(".", "")
# construct the create params
params = self._construct_create_params(create_config, extra_kwargs)
# get the cache_seed, filter_func and context
Expand Down
40 changes: 24 additions & 16 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,15 @@ def test_aoai_chat_completion():
filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
)
client = OpenAIWrapper(config_list=config_list)
# for config in config_list:
# print(config)
# client = OpenAIWrapper(**config)
# response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
print(response)
print(client.extract_text_or_completion_object(response))

# test dialect
config = config_list[0]
config["azure_deployment"] = config["model"]
config["azure_endpoint"] = config.pop("base_url")
client = OpenAIWrapper(**config)
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
print(response)
print(client.extract_text_or_completion_object(response))
Expand Down Expand Up @@ -93,21 +98,23 @@ def test_chat_completion():
def test_completion():
config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
model = "gpt-3.5-turbo-instruct"
response = client.create(prompt="1+1=", model=model)
print(response)
print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.parametrize(
"cache_seed, model",
"cache_seed",
[
(None, "gpt-3.5-turbo-instruct"),
(42, "gpt-3.5-turbo-instruct"),
None,
42,
],
)
def test_cost(cache_seed, model):
def test_cost(cache_seed):
config_list = config_list_openai_aoai(KEY_LOC)
model = "gpt-3.5-turbo-instruct"
client = OpenAIWrapper(config_list=config_list, cache_seed=cache_seed)
response = client.create(prompt="1+3=", model=model)
print(response.cost)
Expand All @@ -117,7 +124,8 @@ def test_cost(cache_seed, model):
def test_usage_summary():
config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=None)
model = "gpt-3.5-turbo-instruct"
response = client.create(prompt="1+3=", model=model, cache_seed=None)

# usage should be recorded
assert client.actual_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
Expand All @@ -138,15 +146,15 @@ def test_usage_summary():
assert client.total_usage_summary is None, "total_usage_summary should be None"

# actual usage and all usage should be different
response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=42)
response = client.create(prompt="1+3=", model=model, cache_seed=42)
assert client.total_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
assert client.actual_usage_summary is None, "No actual cost should be recorded"


if __name__ == "__main__":
test_aoai_chat_completion()
test_oai_tool_calling_extraction()
test_chat_completion()
# test_aoai_chat_completion()
# test_oai_tool_calling_extraction()
# test_chat_completion()
test_completion()
# test_cost()
test_usage_summary()
# # test_cost()
# test_usage_summary()
4 changes: 3 additions & 1 deletion test/oai/test_client_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,9 @@ def test_chat_tools_stream() -> None:
def test_completion_stream() -> None:
config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
# Azure can't have dot in model/deployment name
model = "gpt-35-turbo-instruct" if config_list[0].get("api_type") == "azure" else "gpt-3.5-turbo-instruct"
response = client.create(prompt="1+1=", model=model, stream=True)
print(response)
print(client.extract_text_or_completion_object(response))

Expand Down

0 comments on commit 00dbcb2

Please sign in to comment.