Upgrade OpenAI SDK (#2384)

yifanmai authored Feb 20, 2024
1 parent 898149d commit 149141c
Showing 8 changed files with 36 additions and 49 deletions.
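
In outline, the commit follows the standard 0.x to 1.x migration pattern for the OpenAI Python SDK: module-level configuration (openai.api_key, openai.organization, openai.api_base) and resource classes (openai.Completion, openai.ChatCompletion, openai.Embedding, openai.Moderation, openai.Image) are replaced by methods on an OpenAI client instance; openai.error.OpenAIError becomes openai.OpenAIError; and responses, now Pydantic models rather than plain dicts, are converted back with model_dump(mode="json"). A minimal sketch of the pattern, with a placeholder API key and an illustrative model name:

    import openai
    from openai import OpenAI

    # 0.x style (removed by this commit):
    #   openai.api_key = "sk-..."
    #   response = openai.Completion.create(engine="davinci", prompt="Hello")

    # 1.x style: configuration lives on a client instance, "engine" becomes
    # "model", and the Pydantic response is dumped back to a plain dict.
    client = OpenAI(api_key="sk-...")  # placeholder key
    try:
        response = client.completions.create(model="gpt-3.5-turbo-instruct", prompt="Hello")
        raw = response.model_dump(mode="json")
        print(raw["choices"][0]["text"])
    except openai.OpenAIError as e:  # was openai.error.OpenAIError in 0.x
        print(f"OpenAI error: {e}")
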
setup.cfg (2 changes: 1 addition & 1 deletion)

@@ -125,7 +125,7 @@ mistral =
     mistralai~=0.0.11

 openai =
-    openai~=0.27.8
+    openai~=1.0
     tiktoken~=0.3.3

 google =
src/helm/benchmark/test_model_deployment_definition.py (3 changes: 2 additions & 1 deletion)

@@ -31,7 +31,8 @@ class TestModelProperties:
     @pytest.mark.parametrize("deployment_name", [deployment.name for deployment in ALL_MODEL_DEPLOYMENTS])
     def test_models_has_window_service(self, deployment_name: str):
         with TemporaryDirectory() as tmpdir:
-            auto_client = AutoClient({}, tmpdir, BlackHoleCacheBackendConfig())
+            credentials = {"openaiApiKey": "test-openai-api-key"}
+            auto_client = AutoClient(credentials, tmpdir, BlackHoleCacheBackendConfig())
             auto_tokenizer = AutoTokenizer({}, BlackHoleCacheBackendConfig())
             tokenizer_service = get_tokenizer_service(tmpdir, BlackHoleCacheBackendConfig())
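
The placeholder credential is presumably needed because the v1 SDK validates its API key when the client object is constructed rather than when the first request is sent; a minimal sketch of the behavior this works around, using the same dummy string as the test:

    from openai import OpenAI

    # With no api_key argument and no OPENAI_API_KEY environment variable,
    # OpenAI() raises openai.OpenAIError in its constructor, so fixtures
    # that build clients eagerly need at least a placeholder key.
    client = OpenAI(api_key="test-openai-api-key")
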
src/helm/proxy/clients/auto_client.py (5 changes: 2 additions & 3 deletions)

@@ -13,6 +13,7 @@
 from helm.common.object_spec import create_object, inject_object_spec_args
 from helm.common.request import Request, RequestResult
 from helm.proxy.clients.client import Client
+from helm.proxy.clients.moderation_api_client import ModerationAPIClient
 from helm.proxy.critique.critique_client import CritiqueClient
 from helm.proxy.clients.toxicity_classifier_client import ToxicityClassifierClient
 from helm.proxy.retry import NonRetriableException, retry_request

@@ -150,10 +151,8 @@ def get_toxicity_classifier_client(self) -> ToxicityClassifierClient:
         cache_config: CacheConfig = self.cache_backend_config.get_cache_config("perspectiveapi")
         return PerspectiveAPIClient(self.credentials.get("perspectiveApiKey", ""), cache_config)

-    def get_moderation_api_client(self):
+    def get_moderation_api_client(self) -> ModerationAPIClient:
         """Get the ModerationAPI client."""
-        from .moderation_api_client import ModerationAPIClient
-
         cache_config: CacheConfig = self.cache_backend_config.get_cache_config("ModerationAPI")
         return ModerationAPIClient(self.credentials.get("openaiApiKey", ""), cache_config)
src/helm/proxy/clients/image_generation/dalle2_client.py (13 changes: 4 additions & 9 deletions)

@@ -19,6 +19,7 @@
 try:
     import openai
+    from openai import OpenAI
 except ModuleNotFoundError as missing_module_exception:
     handle_module_not_found_error(missing_module_exception, ["openai"])

@@ -59,12 +60,9 @@ def __init__(
         self.file_cache: FileCache = file_cache
         self._cache = Cache(cache_config)

+        self.client = OpenAI(api_key=api_key, organization=org_id)
         self.moderation_api_client: ModerationAPIClient = moderation_api_client

-        self.org_id: Optional[str] = org_id
-        self.api_key: Optional[str] = api_key
-        self.api_base: str = "https://api.openai.com/v1"
-
     def get_content_policy_violated_result(self, request: Request) -> RequestResult:
         """
         Return a RequestResult with no images and a finish reason indicating that the prompt / generated images

@@ -134,10 +132,7 @@ def generate_with_dalle_api(self, raw_request: Dict[str, Any]) -> Dict:
         """
         Makes a single request to generate the images with the DALL-E API.
         """
-        openai.organization = self.org_id
-        openai.api_key = self.api_key
-        openai.api_base = self.api_base
-        result = openai.Image.create(**raw_request)
+        result = self.client.images.generate(**raw_request).model_dump(mode="json")
         assert "data" in result, f"Invalid response: {result} from prompt: {raw_request['prompt']}"

         for image in result["data"]:

@@ -170,7 +165,7 @@ def do_it():

             cache_key = CachingClient.make_cache_key(raw_request, request)
             response, cached = self._cache.get(cache_key, wrap_request_time(do_it))
-        except openai.error.OpenAIError as e:
+        except openai.OpenAIError as e:
             return self.handle_openai_error(request, e)

         completions: List[Sequence] = [
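
For reference, a minimal sketch of the v1 images call that replaces openai.Image.create; the prompt, n, and size values are illustrative:

    from openai import OpenAI

    client = OpenAI(api_key="sk-...", organization=None)  # organization may be None
    raw_request = {"prompt": "a red cube", "n": 1, "size": "256x256"}
    result = client.images.generate(**raw_request).model_dump(mode="json")
    for image in result["data"]:
        print(image["url"])  # DALL-E 2 responses contain image URLs by default
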
src/helm/proxy/clients/image_generation/dalle3_client.py (2 changes: 1 addition & 1 deletion)

@@ -79,7 +79,7 @@ def do_it():

                 responses.append(response)
                 all_cached = all_cached and cached
-        except openai.error.OpenAIError as e:
+        except openai.OpenAIError as e:
             return self.handle_openai_error(request, e)

         completions: List[Sequence] = []
src/helm/proxy/clients/moderation_api_client.py (12 changes: 8 additions & 4 deletions)

@@ -32,8 +32,13 @@ class ModerationAPIClient:
     VIOLENCE_GRAPHIC: str = "violence/graphic"

     def __init__(self, api_key: str, cache_config: CacheConfig):
-        self.api_key = api_key
         self.cache = Cache(cache_config)
+        try:
+            from openai import OpenAI
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["openai"])
+        # TODO: Add OpenAI organization.
+        self.client = OpenAI(api_key=api_key)

     def get_moderation_results(self, request: ModerationAPIRequest) -> ModerationAPIRequestResult:
         """

@@ -53,13 +58,12 @@ def get_moderation_results(self, request: ModerationAPIRequest) -> ModerationAPI
         try:

             def do_it():
-                openai.api_key = self.api_key
-                result = openai.Moderation.create(input=request.text)
+                result = self.client.moderations.create(input=request.text).model_dump(mode="json")
                 assert "results" in result and len(result["results"]) > 0, f"Invalid response: {result}"
                 return result

             response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
-        except openai.error.OpenAIError as e:
+        except openai.OpenAIError as e:
             error: str = f"Moderation API error: {e}"
             return ModerationAPIRequestResult(
                 success=False, cached=False, error=error, flagged=None, flagged_results=None, scores=None
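
A minimal sketch of the v1 moderation call that replaces openai.Moderation.create; the input text is illustrative:

    from openai import OpenAI

    client = OpenAI(api_key="sk-...")
    result = client.moderations.create(input="text to screen").model_dump(mode="json")
    assert "results" in result and len(result["results"]) > 0
    print(result["results"][0]["flagged"])  # True if any category was flagged
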
src/helm/proxy/clients/openai_client.py (43 changes: 14 additions & 29 deletions)

@@ -17,13 +17,11 @@

 try:
     import openai
+    from openai import OpenAI
 except ModuleNotFoundError as e:
     handle_module_not_found_error(e, ["openai"])


-ORIGINAL_COMPLETION_ATTRIBUTES = openai.api_resources.completion.Completion.__bases__
-
-
 class OpenAIClient(CachingClient):
     END_OF_TEXT: str = "<|endoftext|>"

@@ -43,25 +41,17 @@ def __init__(
         cache_config: CacheConfig,
         api_key: Optional[str] = None,
         org_id: Optional[str] = None,
+        base_url: Optional[str] = None,
     ):
         super().__init__(cache_config=cache_config)
         self.tokenizer = tokenizer
         self.tokenizer_name = tokenizer_name
-        self.org_id: Optional[str] = org_id
-        self.api_key: Optional[str] = api_key
-        self.api_base: str = "https://api.openai.com/v1"
+        self.client = OpenAI(api_key=api_key, organization=org_id, base_url=base_url)

     def _is_chat_model_engine(self, model_engine: str):
         return model_engine.startswith("gpt-3.5") or model_engine.startswith("gpt-4")

-    def _set_access_info(self):
-        # Following https://beta.openai.com/docs/api-reference/authentication
-        # `organization` can be set to None.
-        openai.organization = self.org_id
-        openai.api_key = self.api_key
-        openai.api_base = self.api_base
-
-    def _get_cache_key(self, raw_request, request):
+    def _get_cache_key(self, raw_request: Dict, request: Request):
         cache_key = CachingClient.make_cache_key(raw_request, request)
         if is_vlm(request.model):
             assert request.multimodal_prompt is not None

@@ -74,17 +64,17 @@ def _make_embedding_request(self, request: Request) -> RequestResult:
         raw_request: Dict[str, Any]
         raw_request = {
             "input": request.prompt,
-            "engine": request.model_engine,
+            # Note: In older deprecated versions of the OpenAI API, "model" used to be "engine".
+            "model": request.model_engine,
         }

         def do_it():
-            self._set_access_info()
-            return openai.Embedding.create(**raw_request)
+            return self.client.embeddings.create(**raw_request).model_dump(mode="json")

         try:
             cache_key = self._get_cache_key(raw_request, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except openai.error.OpenAIError as e:
+        except openai.OpenAIError as e:
             error: str = f"OpenAI error: {e}"
             return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

@@ -168,13 +158,12 @@ def _make_chat_request(self, request: Request) -> RequestResult:
             raw_request.pop("stop")

         def do_it():
-            self._set_access_info()
-            return openai.ChatCompletion.create(**raw_request)
+            return self.client.chat.completions.create(**raw_request).model_dump(mode="json")

         try:
             cache_key = self._get_cache_key(raw_request, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except openai.error.OpenAIError as e:
+        except openai.OpenAIError as e:
             if self.INAPPROPRIATE_IMAGE_ERROR in str(e):
                 hlog(f"Failed safety check: {str(request)}")
                 empty_completion = Sequence(

@@ -227,7 +216,8 @@ def do_it():

     def _make_completion_request(self, request: Request) -> RequestResult:
         raw_request: Dict[str, Any] = {
-            "engine": request.model_engine,
+            # Note: In older deprecated versions of the OpenAI API, "model" used to be "engine".
+            "model": request.model_engine,
             "prompt": request.prompt,
             "temperature": request.temperature,
             "n": request.num_completions,

@@ -247,14 +237,12 @@ def _make_completion_request(self, request: Request) -> RequestResult:
             raw_request["logprobs"] = max(raw_request["logprobs"], raw_request["n"])

         def do_it():
-            self._set_access_info()
-            openai.api_resources.completion.Completion.__bases__ = ORIGINAL_COMPLETION_ATTRIBUTES
-            return openai.Completion.create(**raw_request)
+            return self.client.completions.create(**raw_request).model_dump(mode="json")

         try:
             cache_key = self._get_cache_key(raw_request, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except openai.error.OpenAIError as e:
+        except openai.OpenAIError as e:
             error: str = f"OpenAI error: {e}"
             return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

@@ -294,9 +282,6 @@ def do_it():
         )

     def make_request(self, request: Request) -> RequestResult:
-        if self.api_key is None:
-            raise ValueError("OpenAI API key is required")
-
         if request.embedding:
             return self._make_embedding_request(request)
         elif self._is_chat_model_engine(request.model_engine):
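
A minimal sketch of the v1 calls that replace openai.Embedding.create, openai.ChatCompletion.create, and openai.Completion.create; model names are illustrative:

    from openai import OpenAI

    # base_url=None falls back to the default https://api.openai.com/v1,
    # which is why the hard-coded api_base attribute could be dropped.
    client = OpenAI(api_key="sk-...", organization=None, base_url=None)

    # "model" replaces the "engine" parameter of the deprecated 0.x API.
    embedding = client.embeddings.create(model="text-embedding-ada-002", input="Hello")

    chat = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
    )
    # Dump back to a plain dict so cached responses keep the old JSON shape.
    print(chat.model_dump(mode="json")["choices"][0]["message"]["content"])
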
src/helm/proxy/services/server_service.py (5 changes: 4 additions & 1 deletion)

@@ -23,6 +23,7 @@
 from helm.common.hierarchical_logger import hlog
 from helm.proxy.accounts import Accounts, Account
 from helm.proxy.clients.auto_client import AutoClient
+from helm.proxy.clients.moderation_api_client import ModerationAPIClient
 from helm.proxy.clients.perspective_api_client import PerspectiveAPIClient
 from helm.proxy.clients.image_generation.nudity_check_client import NudityCheckClient
 from helm.proxy.clients.gcs_client import GCSClient

@@ -68,9 +69,9 @@ def __init__(
         self.tokenizer = AutoTokenizer(credentials, cache_backend_config)
         self.token_counter = AutoTokenCounter(self.tokenizer)
         self.accounts = Accounts(accounts_path, root_mode=root_mode)
-        self.moderation_api_client = self.client.get_moderation_api_client()

         # Lazily instantiate the following clients
+        self.moderation_api_client: Optional[ModerationAPIClient] = None
         self.toxicity_classifier_client: Optional[ToxicityClassifierClient] = None
         self.perspective_api_client: Optional[PerspectiveAPIClient] = None
         self.nudity_check_client: Optional[NudityCheckClient] = None

@@ -200,6 +201,8 @@ def get_toxicity_scores_with_retry(request: PerspectiveAPIRequest) -> Perspectiv
     def get_moderation_results(self, auth: Authentication, request: ModerationAPIRequest) -> ModerationAPIRequestResult:
         @retry_request
         def get_moderation_results_with_retry(request: ModerationAPIRequest) -> ModerationAPIRequestResult:
+            if not self.moderation_api_client:
+                self.moderation_api_client = self.client.get_moderation_api_client()
             return self.moderation_api_client.get_moderation_results(request)

         self.accounts.authenticate(auth)
