Skip to content

Commit

Permalink
Support GPT-4
Browse files Browse the repository at this point in the history
  • Loading branch information
WongSaang committed Mar 27, 2023
1 parent 621cef0 commit 92240c8
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 41 deletions.
105 changes: 66 additions & 39 deletions chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ def delete_all(self, request):
return Response(status=204)


# Registry of supported chat models and their token budgets.
# Invariant: max_prompt_tokens + max_response_tokens <= max_tokens
# (the model's total context window), otherwise a maximally long prompt
# plus the reserved response budget would overflow the context window.
MODELS = {
    'gpt-3.5-turbo': {
        'name': 'gpt-3.5-turbo',
        'max_tokens': 4096,
        'max_prompt_tokens': 3096,
        'max_response_tokens': 1000
    },
    'gpt-4': {
        'name': 'gpt-4',
        'max_tokens': 8192,
        # was 6196: 6196 + 2000 = 8196 > 8192 broke the budget invariant
        'max_prompt_tokens': 6192,
        'max_response_tokens': 2000
    }
}


def sse_pack(event, data):
# Format data as an SSE message
packet = "event: %s\n" % event
Expand All @@ -107,12 +123,10 @@ def gen_title(request):
{"role": "user", "content": 'Generate a short title for the following content, no more than 10 words: \n\n "%s"' % message.message},
]

model = get_current_model()

myOpenai = get_openai()
try:
openai_response = myOpenai.ChatCompletion.create(
model=model['name'],
model='gpt-3.5-turbo-0301',
messages=messages,
max_tokens=256,
temperature=0.5,
Expand Down Expand Up @@ -145,7 +159,11 @@ def conversation(request):
},
status=status.HTTP_400_BAD_REQUEST
)
model = get_current_model()
model_name = request.data.get('name')
if model_name is None:
model = get_current_model()
else:
model = get_current_model(model_name)
message = request.data.get('message')
conversation_id = request.data.get('conversationId')
max_tokens = request.data.get('max_tokens', model['max_response_tokens'])
Expand All @@ -171,7 +189,7 @@ def conversation(request):
message_obj.save()

try:
messages = build_messages(conversation_obj, web_search_params)
messages = build_messages(model, conversation_obj, web_search_params)

if settings.DEBUG:
print(messages)
Expand All @@ -192,16 +210,22 @@ def stream_content():

myOpenai = get_openai(openai_api_key)

openai_response = myOpenai.ChatCompletion.create(
model=model['name'],
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
stream=True,
)
try:
openai_response = myOpenai.ChatCompletion.create(
model=model['name'],
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
stream=True,
)
except Exception as e:
yield sse_pack('error', {
'error': str(e)
})
return
collected_events = []
completion_text = ''
# iterate through the stream of events
Expand Down Expand Up @@ -232,8 +256,7 @@ def stream_content():
return StreamingHttpResponse(stream_content(), content_type='text/event-stream')


def build_messages(conversation_obj, web_search_params):
model = get_current_model()
def build_messages(model, conversation_obj, web_search_params):

ordered_messages = Message.objects.filter(conversation=conversation_obj).order_by('created_at')
ordered_messages_list = list(ordered_messages)
Expand Down Expand Up @@ -267,14 +290,8 @@ def build_messages(conversation_obj, web_search_params):
return system_messages + messages


def get_current_model(model="gpt-3.5-turbo"):
    """Return the configuration dict for *model* from the MODELS registry.

    The returned dict carries 'name', 'max_tokens', 'max_prompt_tokens'
    and 'max_response_tokens' keys used for prompt budgeting.

    Raises:
        KeyError: if *model* is not a key of MODELS. NOTE(review): a
            client-supplied unknown model name propagates this KeyError
            to the caller — verify callers handle it (or validate the
            name) rather than surfacing a 500.
    """
    return MODELS[model]


def get_openai_api_key():
Expand All @@ -284,26 +301,36 @@ def get_openai_api_key():
return None


def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    """Returns the number of tokens used by a list of messages."""
    # Resolve the tokenizer for the model; fall back to cl100k_base,
    # the encoding shared by the chat model family.
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    # Unpinned aliases delegate to the dated snapshot they currently map to.
    if model == "gpt-3.5-turbo":
        print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0301.")
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
    elif model == "gpt-4":
        print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.")
        return num_tokens_from_messages(messages, model="gpt-4-0314")
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-4-0314":
        tokens_per_message = 3
        tokens_per_name = 1
    else:
        raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
    # Sum per-message framing overhead plus the encoded length of every
    # field value ('role', 'content', and optionally 'name').
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens


def get_openai(openai_api_key = None):
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
Django==4.1.7
gunicorn==20.1.0
openai~=0.27.2
psycopg2~=2.9.5
python-dotenv~=0.21.1
dj-database-url~=1.2.0
djangorestframework~=3.14.0
tiktoken~=0.3.2
djangorestframework-simplejwt~=5.2.2
mysqlclient~=2.1.1
django-allauth~=0.52.0
Expand Down

0 comments on commit 92240c8

Please sign in to comment.