Skip to content

Commit

Permalink
fixup! ✨(backend) create ai endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoLC committed Sep 26, 2024
1 parent fbaa4a7 commit a6d909b
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 12 deletions.
47 changes: 47 additions & 0 deletions src/backend/core/api/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Client serializers for the impress core app."""

import json
import mimetypes
import re

from django.conf import settings
from django.db.models import Q
Expand Down Expand Up @@ -350,3 +352,48 @@ class DocumentVersionSerializer(serializers.Serializer):
is_latest = serializers.BooleanField()
last_modified = serializers.DateTimeField()
version_id = serializers.CharField()


class AIRequestSerializer(serializers.Serializer):
"""Serializer for AI task requests."""

ACTION_CHOICES = [
"prompt",
"correct",
"rephrase",
"summarize",
"translate_en",
"translate_de",
"translate_fr",
]

action = serializers.ChoiceField(choices=ACTION_CHOICES, required=True)
text = serializers.CharField(required=True)

def validate_text(self, value):
"""Ensure the text field is not empty."""

if len(value.strip()) == 0:
raise serializers.ValidationError("Text field cannot be empty.")
return value

def validate_action(self, value):
"""Ensure the action field is valid."""

if value not in self.ACTION_CHOICES:
raise serializers.ValidationError("Invalid action.")
return value

def process_ai_reponse(self, response):
"""Process the response from the AI service."""

content = response.choices[0].message.content
sanitized_content = re.sub(r"(?<!\\)\n", "\\\\n", content)
sanitized_content = re.sub(r"(?<!\\)\t", "\\\\t", sanitized_content)

json_response = json.loads(sanitized_content)

if "answer" not in json_response:
raise serializers.ValidationError("Invalid response format.")

return json_response
18 changes: 8 additions & 10 deletions src/backend/core/api/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,8 +787,6 @@ def create(self, request):
translate_en, translate_de, translate_fr]
Return JSON response with the processed text.
"""
if not request.user.is_authenticated:
raise exceptions.NotAuthenticated()

if (
settings.AI_BASE_URL is None
Expand All @@ -797,8 +795,11 @@ def create(self, request):
):
raise exceptions.ValidationError({"error": "AI configuration not set"})

action = request.data.get("action")
text = request.data.get("text")
serializer = serializers.AIRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)

action = serializer.validated_data["action"]
text = serializer.validated_data["text"]

action_configs = {
"prompt": {
Expand Down Expand Up @@ -868,7 +869,7 @@ def create(self, request):

try:
client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY)
response = client.chat.completions.create(
response_client = client.chat.completions.create(
model=settings.AI_MODEL,
response_format={"type": "json_object"},
messages=[
Expand All @@ -877,12 +878,9 @@ def create(self, request):
],
)

corrected_response = json.loads(response.choices[0].message.content)

if "answer" not in corrected_response:
raise exceptions.ValidationError("Invalid response format")
response = serializer.process_ai_reponse(response_client)

return drf_response.Response(corrected_response, status=status.HTTP_200_OK)
return drf_response.Response(response, status=status.HTTP_200_OK)

except exceptions.ValidationError as e:
return drf_response.Response(
Expand Down
38 changes: 36 additions & 2 deletions src/backend/core/tests/test_api_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_api_ai__bad_action_config():
)

assert response.status_code == 400
assert response.json() == {"error": "Invalid action"}
assert response.json() == {"action": ['"bad_action" is not a valid choice.']}


@override_settings(
Expand Down Expand Up @@ -156,7 +156,7 @@ def test_api_ai__client_invalid_response():
)

assert response.status_code == 400
assert response.json() == {"error": ["Invalid response format"]}
assert response.json() == {"error": ["Invalid response format."]}


@override_settings(
Expand Down Expand Up @@ -189,3 +189,37 @@ def test_api_ai__success():

assert response.status_code == 200
assert response.json() == {"answer": "Salut le monde"}


@override_settings(
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
)
def test_api_ai__success_sanitize():
"""
Test the ai response is sanitized
"""
user = factories.UserFactory()

client = APIClient()
client.force_login(user)

with patch("openai.resources.chat.completions.Completions.create") as mock_create:
mock_create.return_value = ChatCompletionMock(
id="test-id",
choices=[
ChoiceMock(
message=MessageMock(content='{"answer": "Salut\\n \tle \nmonde"}')
)
],
)

response = client.post(
"/api/v1.0/ai/",
{
"action": "translate_fr",
"text": "Hello world",
},
)

assert response.status_code == 200
assert response.json() == {"answer": "Salut\n \tle \nmonde"}

0 comments on commit a6d909b

Please sign in to comment.