Merge pull request #1 from WongSaang/rest-auth
Support for the official ChatGPT model: gpt-3.5-turbo
WongSaang authored Mar 3, 2023
2 parents beecb88 + 89b7876 commit 2b46868
Showing 15 changed files with 217 additions and 92 deletions.
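The core of this commit is the switch in chat/views.py from the Completions endpoint (text-davinci-003, one prompt string) to the Chat Completions endpoint (gpt-3.5-turbo, a list of role-tagged messages). A minimal side-by-side sketch using the pre-1.0 openai Python client this project depends on; the API key and prompts are placeholders:

import openai

openai.api_key = "sk-..."  # placeholder key

# Before this commit: a single prompt string against the Completions endpoint.
completion = openai.Completion.create(
    model="text-davinci-003",
    prompt="Human: Hello\nAI:",
    max_tokens=100,
)
print(completion["choices"][0]["text"])

# After this commit: role-tagged messages against the Chat Completions endpoint.
chat = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ],
)
print(chat["choices"][0]["message"]["content"])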
6 changes: 1 addition & 5 deletions Dockerfile
@@ -12,13 +12,9 @@ COPY requirements.txt ./

RUN pip install --no-cache-dir -i https://mirrors.cloud.tencent.com/pypi/simple -r requirements.txt

RUN groupadd -r appgroup && useradd -r -g appgroup appuser && mkdir -p /app && chown appuser /app

USER appuser

WORKDIR /app

COPY --chown=appuser . .
COPY . .

RUN python manage.py check --deploy \
&& python manage.py collectstatic --no-input \
Empty file added account/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions account/admin.py
@@ -0,0 +1,3 @@
from django.contrib import admin

# Register your models here.
8 changes: 8 additions & 0 deletions account/allauth.py
@@ -0,0 +1,8 @@
from allauth.account.adapter import DefaultAccountAdapter
from allauth.utils import build_absolute_uri

class AccountAdapter(DefaultAccountAdapter):

def get_email_confirmation_url(self, request, emailconfirmation):
location = '/account/verify-email/{}'.format(emailconfirmation.key)
return build_absolute_uri(None, location)
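For django-allauth to use this adapter (and therefore the SPA-friendly email-confirmation URL), it has to be registered in settings. That change is not visible in this excerpt; the usual wiring, assuming the module path above, would be:

# settings.py -- assumed wiring, not shown in this diff
ACCOUNT_ADAPTER = 'account.allauth.AccountAdapter'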
6 changes: 6 additions & 0 deletions account/apps.py
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class AccountConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'account'
Empty file added account/migrations/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions account/models.py
@@ -0,0 +1,3 @@
from django.db import models

# Create your models here.
40 changes: 40 additions & 0 deletions account/serializers.py
@@ -0,0 +1,40 @@
from django.conf import settings
from django.contrib.auth import authenticate, get_user_model
from rest_framework import exceptions, serializers

# Get the UserModel
UserModel = get_user_model()

class UserDetailsSerializer(serializers.ModelSerializer):
"""
User model w/o password
"""

@staticmethod
def validate_username(username):
if 'allauth.account' not in settings.INSTALLED_APPS:
# We don't need to call the all-auth
# username validator unless it's installed
return username

from allauth.account.adapter import get_adapter
username = get_adapter().clean_username(username)
return username

class Meta:
extra_fields = []
# see https://github.com/iMerica/dj-rest-auth/issues/181
# UserModel.XYZ causing attribute error while importing other
# classes from `serializers.py`. So, we need to check whether the auth model has
# the attribute or not
if hasattr(UserModel, 'USERNAME_FIELD'):
extra_fields.append(UserModel.USERNAME_FIELD)
if hasattr(UserModel, 'EMAIL_FIELD'):
extra_fields.append(UserModel.EMAIL_FIELD)
if hasattr(UserModel, 'first_name'):
extra_fields.append('first_name')
if hasattr(UserModel, 'last_name'):
extra_fields.append('last_name')
model = UserModel
fields = (*extra_fields,)
read_only_fields = ('email',)
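If this serializer is meant to replace dj-rest-auth's default user-details serializer, it would normally be registered in settings as well; a hedged sketch, since the settings diff is not shown here:

# settings.py -- assumed wiring, not shown in this diff
REST_AUTH_SERIALIZERS = {
    'USER_DETAILS_SERIALIZER': 'account.serializers.UserDetailsSerializer',
}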
3 changes: 3 additions & 0 deletions account/tests.py
@@ -0,0 +1,3 @@
from django.test import TestCase

# Create your tests here.
4 changes: 4 additions & 0 deletions account/views.py
@@ -0,0 +1,4 @@
from django.shortcuts import render

# Create your views here.
from dj_rest_auth.views import LoginView
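Only the LoginView import is visible here; the URL wiring for dj-rest-auth is not part of this excerpt. A typical hookup, with assumed paths, would look roughly like:

# urls.py -- assumed, not part of this diff excerpt
from django.urls import include, path

urlpatterns = [
    path('api/account/', include('dj_rest_auth.urls')),
    path('api/account/registration/', include('dj_rest_auth.registration.urls')),
]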
153 changes: 91 additions & 62 deletions chat/views.py
@@ -4,6 +4,7 @@
import datetime
import tiktoken
from .models import Conversation, Message, Setting
from django.conf import settings
from django.http import StreamingHttpResponse
from rest_framework import viewsets, status
from rest_framework.response import Response
@@ -15,7 +16,7 @@

class ConversationViewSet(viewsets.ModelViewSet):
serializer_class = ConversationSerializer
authentication_classes = [JWTAuthentication]
# authentication_classes = [JWTAuthentication]
permission_classes = [IsAuthenticated]

def get_queryset(self):
@@ -24,21 +25,12 @@ def get_queryset(self):

class MessageViewSet(viewsets.ModelViewSet):
serializer_class = MessageSerializer
authentication_classes = [JWTAuthentication]
# authentication_classes = [JWTAuthentication]
permission_classes = [IsAuthenticated]

def get_queryset(self):
return Message.objects.filter(conversation_id=self.request.query_params.get('conversationId')).order_by('created_at')


@api_view(['GET'])
@authentication_classes([JWTAuthentication])
@permission_classes([IsAuthenticated])
def get_current_user(request):
user = request.user
return Response({
'username': user.username,
})
return Message.objects.filter(conversation_id=self.request.query_params.get('conversationId')).order_by(
'created_at')


def sse_pack(event, data):
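The body of sse_pack is folded in this view. It frames an event name and a JSON payload in the text/event-stream format that the streaming endpoint below yields; a plausible implementation, assumed since the hunk is collapsed:

import json

def sse_pack(event, data):
    # Server-Sent Events framing: event name, one JSON data line, blank-line terminator.
    packet = 'event: %s\n' % event
    packet += 'data: %s\n' % json.dumps(data)
    packet += '\n'
    return packet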
@@ -50,20 +42,20 @@ def sse_pack(event, data):


@api_view(['POST'])
@authentication_classes([JWTAuthentication])
# @authentication_classes([JWTAuthentication])
@permission_classes([IsAuthenticated])
def gen_title(request):
conversation_id = request.data.get('conversationId')
conversation = Conversation.objects.get(id=conversation_id)
conversation_obj = Conversation.objects.get(id=conversation_id)
message = Message.objects.filter(conversation_id=conversation_id).order_by('created_at').first()
prompt = f'''Generate a title of ten words or less from the following text:
[{message.message}]
Title:
'''

openai.api_key = get_openai_api_key()
myOpenai = get_openai()
try:
openai_response = openai.Completion.create(
openai_response = myOpenai.Completion.create(
model='text-davinci-003',
prompt=prompt,
temperature=0.5,
@@ -77,20 +69,25 @@ def gen_title(request):
except:
title = 'Untitled Conversation'
# update the conversation title
conversation.topic = title
conversation.save()
conversation_obj.topic = title
conversation_obj.save()
return Response({
'title': title
})


@api_view(['POST'])
@authentication_classes([JWTAuthentication])
# @authentication_classes([JWTAuthentication])
@permission_classes([IsAuthenticated])
def conversation(request):
api_key = get_openai_api_key()
if api_key is None:
return Response({'error': 'The administrator has not set the API key'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return Response(
{
'error': 'The administrator has not set the API key'
},
status=status.HTTP_400_BAD_REQUEST
)
model = get_current_model()
message = request.data.get('message')
conversation_id = request.data.get('conversationId')
@@ -111,24 +108,34 @@ def conversation(request):
)
message_obj.save()

prompt = build_prompt(conversation_obj)
try:
messages = build_messages(conversation_obj)

if settings.DEBUG:
print(messages)
except ValueError as e:
return Response(
{
'error': e
},
status=status.HTTP_400_BAD_REQUEST
)
# print(prompt)

num_tokens = get_token_count(prompt)
num_tokens = num_tokens_from_messages(messages)
max_tokens = min(model['max_tokens'] - num_tokens, model['max_response_tokens'])

def stream_content():
openai.api_key = api_key
myOpenai = get_openai()

openai_response = openai.Completion.create(
openai_response = myOpenai.ChatCompletion.create(
model=model['name'],
prompt=prompt,
messages=messages,
max_tokens=max_tokens,
temperature=0.9,
temperature=0.7,
top_p=1,
frequency_penalty=0.0,
presence_penalty=0.6,
stop=[" Human:", " AI:"],
frequency_penalty=0,
presence_penalty=0,
stream=True,
)
collected_events = []
@@ -139,10 +146,13 @@ def stream_content():
# print(event)
if event['choices'][0]['finish_reason'] is not None:
break
event_text = event['choices'][0]['text'] # extract the text
completion_text += event_text # append the text
# print(event)
yield sse_pack('message', {'content': event_text})
# if debug
if settings.DEBUG:
print(event)
if 'content' in event['choices'][0]['delta']:
event_text = event['choices'][0]['delta']['content']
completion_text += event_text # append the text
yield sse_pack('message', {'content': event_text})

ai_message_obj = Message(
conversation_id=conversation_obj.id,
@@ -156,59 +166,78 @@ def stream_content():
return StreamingHttpResponse(stream_content(), content_type='text/event-stream')
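A hedged client-side sketch (not part of the diff) showing how the event stream above can be consumed with the requests library; the URL and credentials are placeholders:

import json
import requests

resp = requests.post(
    'http://localhost:8000/api/conversation/',      # assumed route
    json={'message': 'Hello', 'conversationId': None},
    headers={'Authorization': 'Bearer <token>'},     # placeholder credential
    stream=True,
)
for raw_line in resp.iter_lines(decode_unicode=True):
    if raw_line and raw_line.startswith('data: '):
        payload = json.loads(raw_line[len('data: '):])
        print(payload.get('content', ''), end='')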


def build_prompt(conversation_obj):
def build_messages(conversation_obj):
model = get_current_model()

ordered_messages = Message.objects.filter(conversation=conversation_obj).order_by('created_at')
ordered_messages_list = list(ordered_messages)

ai_label = 'AI'
user_label = 'Human'
current_date_string = datetime.datetime.today().strftime('%B %d, %Y')
prompt_prefix = f'Instructions:\nYou are ChatGPT, a large language model trained by OpenAI.\nCurrent date: {current_date_string}\n'
prompt_suffix = f"{ai_label}:"
system_messages = [{"role": "system", "content": "You are a helpful assistant."}]

current_token_count = num_tokens_from_messages(system_messages, model['name'])

current_token_count = get_token_count(f"{prompt_prefix}{prompt_suffix}")
prompt_body = ''
max_token_count = model['max_prompt_tokens']

messages = []

while current_token_count < max_token_count and len(ordered_messages_list) > 0:
message = ordered_messages_list.pop()
role_label = ai_label if message.is_bot else user_label
message_string = f"{role_label}: {message.message}\n"
if prompt_body:
new_prompt_body = f"{message_string}{prompt_body}"
else:
new_prompt_body = f"{prompt_prefix}{message_string}{prompt_body}"

new_token_count = get_token_count(f"{prompt_prefix}{new_prompt_body}{prompt_suffix}")
role = "assistant" if message.is_bot else "user"
new_message = {"role": role, "content": message.message}
new_token_count = num_tokens_from_messages(system_messages + messages + [new_message])
if new_token_count > max_token_count:
if prompt_body:
if len(messages) > 0:
break
raise ValueError(f"Prompt is too long. Max token count is {max_token_count}, but prompt is {new_token_count} tokens long.")
prompt_body = new_prompt_body
raise ValueError(
f"Prompt is too long. Max token count is {max_token_count}, but prompt is {new_token_count} tokens long.")
messages.insert(0, new_message)
current_token_count = new_token_count

prompt = f"{prompt_body}{prompt_suffix}"
return system_messages + messages

return prompt

def get_current_model():
model = {
'name': 'text-davinci-003',
'max_tokens': 4097,
'max_prompt_tokens': 3097,
'name': 'gpt-3.5-turbo',
'max_tokens': 4096,
'max_prompt_tokens': 3096,
'max_response_tokens': 1000
}
return model


def get_openai_api_key():
row = Setting.objects.filter(name='openai_api_key').first()
if row:
return row.value
return None

def get_token_count(token):
model = get_current_model()
enc = tiktoken.encoding_for_model(model['name'])
return len(enc.encode(token))

def num_tokens_from_messages(messages, model="gpt-3.5-turbo"):
"""Returns the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
if model == "gpt-3.5-turbo": # note: future models may deviate from this
num_tokens = 0
for message in messages:
num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name": # if there's a name, the role is omitted
num_tokens += -1 # role is always required and always 1 token
num_tokens += 2 # every reply is primed with <im_start>assistant
return num_tokens
else:
raise NotImplementedError(f"""num_tokens_from_messages() is not presently implemented for model {model}. See
https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to
tokens.""")


def get_openai():
openai.api_key = get_openai_api_key()
proxy = os.getenv('OPENAI_API_PROXY')
if proxy:
openai.api_base = proxy
return openai
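get_openai lets a deployment point the client at an alternative, OpenAI-compatible base URL read from the environment. A usage sketch; the proxy URL is a placeholder, and in practice the variable would be set on the container rather than in code:

import os

os.environ['OPENAI_API_PROXY'] = 'https://openai-proxy.example.com/v1'  # hypothetical proxy

client = get_openai()
response = client.ChatCompletion.create(
    model='gpt-3.5-turbo',
    messages=[{'role': 'user', 'content': 'ping'}],
)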
