From 1282bd6f5a87a344a182ba592d16b87a4bd9a6b8 Mon Sep 17 00:00:00 2001 From: Rafi Date: Thu, 9 Mar 2023 17:40:12 +0800 Subject: [PATCH] Add frequently used prompt function. --- chat/models.py | 7 +++++++ chat/serializers.py | 8 +++++++- chat/urls.py | 3 ++- chat/views.py | 29 +++++++++++++++++++++++++++-- 4 files changed, 43 insertions(+), 4 deletions(-) diff --git a/chat/models.py b/chat/models.py index c3344c6..4a058de 100644 --- a/chat/models.py +++ b/chat/models.py @@ -16,6 +16,13 @@ class Message(models.Model): created_at = models.DateTimeField(auto_now_add=True) +class Prompt(models.Model): + user = models.ForeignKey(User, on_delete=models.CASCADE) + prompt = models.TextField() + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Setting(models.Model): name = models.CharField(max_length=255) value = models.CharField(max_length=255) diff --git a/chat/serializers.py b/chat/serializers.py index a5ca424..fb677e3 100644 --- a/chat/serializers.py +++ b/chat/serializers.py @@ -1,5 +1,5 @@ from rest_framework import serializers -from .models import Conversation, Message +from .models import Conversation, Message, Prompt class ConversationSerializer(serializers.ModelSerializer): class Meta: @@ -10,3 +10,9 @@ class MessageSerializer(serializers.ModelSerializer): class Meta: model = Message fields = ['parent_message', 'message', 'is_bot', 'created_at'] + + +class PromptSerializer(serializers.ModelSerializer): + class Meta: + model = Prompt + fields = ['id', 'prompt', 'created_at', 'updated_at'] \ No newline at end of file diff --git a/chat/urls.py b/chat/urls.py index 8e965be..ccab8bc 100644 --- a/chat/urls.py +++ b/chat/urls.py @@ -1,10 +1,11 @@ from django.urls import include, path from rest_framework import routers -from .views import ConversationViewSet, MessageViewSet +from .views import ConversationViewSet, MessageViewSet, PromptViewSet router = routers.SimpleRouter() router.register(r'conversations', ConversationViewSet, basename='conversationModel') router.register(r'messages', MessageViewSet, basename='messageModel') +router.register(r'prompts', PromptViewSet, basename='promptModel') # Wire up our API using automatic URL routing. # Additionally, we include login URLs for the browsable API. diff --git a/chat/views.py b/chat/views.py index dbcc069..e278735 100644 --- a/chat/views.py +++ b/chat/views.py @@ -3,7 +3,7 @@ import openai import datetime import tiktoken -from .models import Conversation, Message, Setting +from .models import Conversation, Message, Setting, Prompt from django.conf import settings from django.http import StreamingHttpResponse from rest_framework import viewsets, status @@ -11,7 +11,7 @@ from rest_framework.permissions import IsAuthenticated from rest_framework_simplejwt.authentication import JWTAuthentication from rest_framework.decorators import api_view, authentication_classes, permission_classes, action -from .serializers import ConversationSerializer, MessageSerializer +from .serializers import ConversationSerializer, MessageSerializer, PromptSerializer class ConversationViewSet(viewsets.ModelViewSet): @@ -39,6 +39,31 @@ def get_queryset(self): 'created_at') +class PromptViewSet(viewsets.ModelViewSet): + serializer_class = PromptSerializer + # authentication_classes = [JWTAuthentication] + permission_classes = [IsAuthenticated] + + def get_queryset(self): + return Prompt.objects.filter(user=self.request.user).order_by('-created_at') + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + serializer.validated_data['user'] = request.user + + self.perform_create(serializer) + headers = self.get_success_headers(serializer.data) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + + @action(detail=False, methods=['delete']) + def delete_all(self, request): + queryset = self.filter_queryset(self.get_queryset()) + queryset.delete() + return Response(status=204) + + def sse_pack(event, data): # Format data as an SSE message packet = "event: %s\n" % event