forked from Sider-ai/chatgpt-retrieval-plugin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathopenai.py
78 lines (62 loc) · 2.41 KB
/
openai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import List
import openai
import os
from loguru import logger
from tenacity import retry, wait_random_exponential, stop_after_attempt
# Embedding model name sent by get_embeddings when no Azure deployment id is
# configured; override it through the EMBEDDING_MODEL environment variable.
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-large")
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(3))
def get_embeddings(texts: List[str]) -> List[List[float]]:
    """
    Embed texts using an OpenAI embedding model.

    By default the request names the module-level ``EMBEDDING_MODEL`` (taken
    from the ``EMBEDDING_MODEL`` environment variable, default
    ``"text-embedding-3-large"``). If the ``OPENAI_EMBEDDINGMODEL_DEPLOYMENTID``
    environment variable is set, the request targets that Azure OpenAI
    deployment instead of a model name.

    Args:
        texts: The list of texts to embed.

    Returns:
        A list of embeddings, each of which is a list of floats, taken from
        the ``data`` entries of the API response. An empty ``texts`` yields an
        empty list without calling the API.

    Raises:
        Exception: If the OpenAI API call fails. The ``@retry`` decorator makes
            up to three attempts with randomized exponential backoff (1-20 s);
            after the last failed attempt tenacity raises ``RetryError``
            wrapping the final error (it is not configured with
            ``reraise=True``).
    """
    if not texts:
        # Nothing to embed: skip a request the API would reject, and the
        # three backoff retries such a rejection would trigger.
        return []

    # NOTE: Azure OpenAI requires a deployment id instead of a model name.
    deployment = os.environ.get("OPENAI_EMBEDDINGMODEL_DEPLOYMENTID")
    if deployment is None:
        response = openai.Embedding.create(input=texts, model=EMBEDDING_MODEL)
    else:
        response = openai.Embedding.create(input=texts, deployment_id=deployment)

    # Each entry of response["data"] carries one embedding vector.
    data = response["data"]  # type: ignore
    return [result["embedding"] for result in data]
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(3))
def get_chat_completion(
    messages,
    model="gpt-3.5-turbo",  # use "gpt-4" for better results
    deployment_id=None,
):
    """
    Generate a chat completion using OpenAI's chat completion API.

    Args:
        messages: The list of messages in the chat history.
        model: The name of the model to use for the completion. Default is
            gpt-3.5-turbo, which is a fast, cheap and versatile model. Use
            gpt-4 for higher quality but slower results. Ignored when
            ``deployment_id`` is given.
        deployment_id: Optional Azure OpenAI deployment id. When set, the
            request targets this deployment instead of ``model``.

    Returns:
        The text of the first choice with surrounding whitespace stripped,
        or an empty string if that choice carries no text content.

    Raises:
        Exception: If the OpenAI API call fails. The ``@retry`` decorator makes
            up to three attempts with randomized exponential backoff (1-20 s);
            after the last failed attempt tenacity raises ``RetryError``
            wrapping the final error (it is not configured with
            ``reraise=True``).
    """
    # NOTE: Azure OpenAI requires a deployment id instead of a model name,
    # so exactly one of the two is sent with the request.
    if deployment_id is None:
        target = {"model": model}
    else:
        target = {"deployment_id": deployment_id}
    response = openai.ChatCompletion.create(messages=messages, **target)

    choices = response["choices"]  # type: ignore
    # The message "content" field can be null (e.g. when the model answers
    # with a function call), so guard before stripping instead of failing
    # with AttributeError on None.
    content = choices[0]["message"].get("content")
    completion = (content or "").strip()
    logger.info(f"Completion: {completion}")
    return completion