Add unit tests for inference function #91


Draft · wants to merge 7 commits into base: main
14 changes: 12 additions & 2 deletions .github/workflows/test-pr.yml
@@ -35,6 +35,10 @@ jobs:
chart_validation:
needs: [publish_images]
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
chart: ["azimuth-llm-chat", "azimuth-llm-image-analysis", "azimuth-llm", "flux-image-gen"]
env:
CLUSTER_NAME: chart-testing
RELEASE_NAME: ci-test
@@ -59,13 +63,16 @@ jobs:
uses: helm/chart-testing-action@v2

- name: Run chart linting
run: ct lint --config ct.yaml
run: ct lint --config ct.yaml --charts charts/${{ matrix.chart }}

- name: Create Kind Cluster
uses: helm/kind-action@v1
with:
cluster_name: ${{ env.CLUSTER_NAME }}

- name: Debug
run: (df -h && docker ps -a && docker image ls && kubectl get nodes && kubectl get pods --all-namespaces) || true

# NOTE(scott): Since the local Chart.yaml uses "appVersion: latest" and this
# only gets overwritten to the correct commit SHA during Helm chart build,
# we need to pull these published images and load them into the kind cluster
@@ -74,11 +81,14 @@ jobs:
run: ./kind-images.sh $(git rev-parse --short ${{ github.event.pull_request.head.sha }}) ${{ env.CLUSTER_NAME }}
working-directory: web-apps

- name: Debug
run: (df -h && docker ps -a && docker image ls && kubectl get nodes && kubectl get pods --all-namespaces) || true

# https://github.com/helm/charts/blob/master/test/README.md#providing-custom-test-values
# Each chart/ci/*-values.yaml file will be treated as a separate test case with its
# own helm install/test process.
- name: Run chart install and test
run: ct install --config ct.yaml
run: ct install --config ct.yaml --charts charts/${{ matrix.chart }}

publish_charts:
needs: [chart_validation]
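Note on the workflow change: with fail-fast: false, the matrix fans chart_validation out into four parallel jobs, one per chart, so a failure in one chart no longer cancels validation of the others; the two Debug steps are wrapped in (...) || true so the diagnostics can never fail a job themselves. Each matrix job effectively runs the pair of commands below (sketched here with one chart name substituted from the matrix; the other three charts run the same way in their own jobs):

    # e.g. for matrix.chart = "azimuth-llm-chat"
    ct lint --config ct.yaml --charts charts/azimuth-llm-chat
    ct install --config ct.yaml --charts charts/azimuth-llm-chat
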
4 changes: 0 additions & 4 deletions ct.yaml
@@ -1,9 +1,5 @@
# Complains about invalid maintainer URLs
validate-maintainers: false
# Skip version bump detection and lint all charts
# since we're using the azimuth-cloud Helm chart publish
# workflow which doesn't use Chart.yaml's version key
all: true
# Split output to make it look nice in GitHub Actions tab
github-groups: true
# Allow for long running install and test processes
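For context: without all: true, chart-testing falls back to its default of linting and installing only charts that changed relative to the target branch; since the workflow above now names the chart set explicitly with --charts, the blanket override is no longer needed.
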
2 changes: 2 additions & 0 deletions web-apps/chat/Dockerfile
@@ -4,6 +4,8 @@ ARG DIR=chat

COPY $DIR/requirements.txt requirements.txt
COPY utils utils
RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir --upgrade setuptools
RUN pip install --no-cache-dir -r requirements.txt

COPY purge-google-fonts.sh purge-google-fonts.sh
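The same two-step pip/setuptools upgrade is repeated in the flux-image-gen and image-analysis Dockerfiles below. A minor alternative (not what this PR does) would be to collapse the two upgrades into a single image layer:

    RUN pip install --no-cache-dir --upgrade pip setuptools
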
43 changes: 25 additions & 18 deletions web-apps/chat/app.py
@@ -61,30 +61,37 @@ class PossibleSystemPromptException(Exception):
streaming=True,
)

def build_chat_context(latest_message, history):
"""
Build the chat context from the latest message and history.
"""
context = []
if INCLUDE_SYSTEM_PROMPT:
context.append(SystemMessage(content=settings.model_instruction))
else:
# Mimic system prompt by prepending it to first human message
history[0]['content'] = f"{settings.model_instruction}\n\n{history[0]['content']}"

for message in history:
role = message['role']
content = message['content']
if role == "user":
context.append(HumanMessage(content=content))
else:
if role != "assistant":
log.warn(f"Message role {role} converted to 'assistant'")
context.append(AIMessage(content=(content or "")))
context.append(HumanMessage(content=latest_message))
return context


def inference(latest_message, history):
# Allow mutating global variable
global BACKEND_INITIALISED
log.debug("Inference request received with history: %s", history)

try:
context = []
if INCLUDE_SYSTEM_PROMPT:
context.append(SystemMessage(content=settings.model_instruction))
else:
# Mimic system prompt by prepending it to first human message
history[0]['content'] = f"{settings.model_instruction}\n\n{history[0]['content']}"

for message in history:
role = message['role']
content = message['content']
if role == "user":
context.append(HumanMessage(content=content))
else:
if role != "assistant":
log.warn(f"Message role {role} converted to 'assistant'")
context.append(AIMessage(content=(content or "")))
context.append(HumanMessage(content=latest_message))

context = build_chat_context(latest_message, history)
log.debug("Chat context: %s", context)

response = ""
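The extraction above is what makes the new unit tests possible: build_chat_context can now be called directly, without going through Gradio or a live backend. One behaviour worth noting is that when INCLUDE_SYSTEM_PROMPT is false it mutates history[0] in place to prepend the model instruction. A minimal sketch of calling the helper directly, using the message shape from this file (importing app also pulls in its Gradio and LLM client setup, which is why the tests below patch those pieces instead):

    from app import build_chat_context

    history = [
        {"role": "user", "content": "Hi!"},
        {"role": "assistant", "content": "Hello! How can I help you?"},
    ]
    context = build_chat_context("Why is the sky blue?", history)
    # With INCLUDE_SYSTEM_PROMPT enabled the result is
    # [SystemMessage, HumanMessage, AIMessage, HumanMessage]
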
187 changes: 178 additions & 9 deletions web-apps/chat/test.py
@@ -1,23 +1,192 @@
import openai
import os
import unittest

# from unittest import mock
from gradio_client import Client
from unittest.mock import patch, MagicMock, Mock
from langchain.schema import HumanMessage, AIMessage, SystemMessage
from app import build_chat_context, inference, PossibleSystemPromptException, gr

url = os.environ.get("GRADIO_URL", "http://localhost:7860")
client = Client(url)
latest_message = "Why don't humans drink horse milk?"
history = [
{
"role": "user",
"metadata": None,
"content": "Hi!",
"options": None,
},
{
"role": "assistant",
"metadata": None,
"content": "Hello! How can I help you?",
"options": None,
},
]

class TestSuite(unittest.TestCase):

class TestAPI(unittest.TestCase):
def test_gradio_api(self):
result = client.predict("Hi", api_name="/chat")
self.assertGreater(len(result), 0)

# def test_mock_response(self):
# with mock.patch('app.client.stream_response', return_value=(char for char in "Mocked")) as mock_response:
# result = client.predict("Hi", api_name="/chat")
# # mock_response.assert_called_once_with("Hi", [])
# self.assertEqual(result, "Mocked")
class TestBuildChatContext(unittest.TestCase):
@patch("app.settings")
@patch("app.INCLUDE_SYSTEM_PROMPT", True)
def test_chat_context_system_prompt(self, mock_settings):
mock_settings.model_instruction = "You are a helpful assistant."

context = build_chat_context(latest_message, history)

self.assertEqual(len(context), 4)
self.assertIsInstance(context[0], SystemMessage)
self.assertEqual(context[0].content, "You are a helpful assistant.")
self.assertIsInstance(context[1], HumanMessage)
self.assertEqual(context[1].content, history[0]["content"])
self.assertIsInstance(context[2], AIMessage)
self.assertEqual(context[2].content, history[1]["content"])
self.assertIsInstance(context[3], HumanMessage)
self.assertEqual(context[3].content, latest_message)

@patch("app.settings")
@patch("app.INCLUDE_SYSTEM_PROMPT", False)
def test_chat_context_human_prompt(self, mock_settings):
mock_settings.model_instruction = "You are a very helpful assistant."

context = build_chat_context(latest_message, history)

self.assertEqual(len(context), 3)
self.assertIsInstance(context[0], HumanMessage)
self.assertEqual(context[0].content, "You are a very helpful assistant.\n\nHi!")
self.assertIsInstance(context[1], AIMessage)
self.assertEqual(context[1].content, history[1]["content"])
self.assertIsInstance(context[2], HumanMessage)
self.assertEqual(context[2].content, latest_message)

class TestInference(unittest.TestCase):
@patch("app.settings")
@patch("app.llm")
@patch("app.log")
def test_inference_success(self, mock_logger, mock_llm, mock_settings):
mock_llm.stream.return_value = [MagicMock(content="response_chunk")]

mock_settings.model_instruction = "You are a very helpful assistant."

responses = list(inference(latest_message, history))

self.assertEqual(responses, ["response_chunk"])
mock_logger.debug.assert_any_call("Inference request received with history: %s", history)

@patch("app.llm")
@patch("app.build_chat_context")
def test_inference_thinking_tags(self, mock_build_chat_context, mock_llm):
mock_build_chat_context.return_value = ["mock_context"]
mock_llm.stream.return_value = [
MagicMock(content="<think>"),
MagicMock(content="processing"),
MagicMock(content="</think>"),
MagicMock(content="final response"),
]

responses = list(inference(latest_message, history))

self.assertEqual(responses, ["Thinking...", "Thinking...", "", "final response"])

@patch("app.llm")
@patch("app.INCLUDE_SYSTEM_PROMPT", True)
@patch("app.build_chat_context")
@patch("app.log")
def test_inference_PossibleSystemPromptException(self, mock_logger, mock_build_chat_context, mock_llm):
mock_build_chat_context.return_value = ["mock_context"]
mock_response = Mock()
mock_response.json.return_value = {"message": "Bad request"}

mock_llm.stream.side_effect = openai.BadRequestError(
message="Bad request",
response=mock_response,
body=None
)

with self.assertRaises(PossibleSystemPromptException):
list(inference(latest_message, history))
mock_logger.error.assert_called_once_with("Received BadRequestError from backend API: %s", mock_llm.stream.side_effect)

@patch("app.llm")
@patch("app.INCLUDE_SYSTEM_PROMPT", False)
@patch("app.build_chat_context")
@patch("app.log")
def test_inference_general_error(self, mock_logger, mock_build_chat_context, mock_llm):
mock_build_chat_context.return_value = ["mock_context"]
mock_response = Mock()
mock_response.json.return_value = {"message": "Bad request"}

mock_llm.stream.side_effect = openai.BadRequestError(
message="Bad request",
response=mock_response,
body=None
)

exception_message = "\'API Error received. This usually means the chosen LLM uses an incompatible prompt format. Error message was: Bad request\'"

with self.assertRaises(gr.Error) as gradio_error:
list(inference(latest_message, history))
self.assertEqual(str(gradio_error.exception), exception_message)
mock_logger.error.assert_called_once_with("Received BadRequestError from backend API: %s", mock_llm.stream.side_effect)

@patch("app.llm")
@patch("app.build_chat_context")
@patch("app.log")
@patch("app.gr")
@patch("app.BACKEND_INITIALISED", False)
def test_inference_APIConnectionError(self, mock_gr, mock_logger, mock_build_chat_context, mock_llm):
mock_build_chat_context.return_value = ["mock_context"]
mock_request = Mock()
mock_request.json.return_value = {"message": "Foo"}

mock_llm.stream.side_effect = openai.APIConnectionError(
message="Foo",
request=mock_request,
)

list(inference(latest_message, history))
mock_logger.info.assert_any_call("Backend API not yet ready")
mock_gr.Info.assert_any_call("Backend not ready - model may still be initialising - please try again later.")

@patch("app.llm")
@patch("app.build_chat_context")
@patch("app.log")
@patch("app.gr")
@patch("app.BACKEND_INITIALISED", True)
def test_inference_APIConnectionError_initialised(self, mock_gr, mock_logger, mock_build_chat_context, mock_llm):
mock_build_chat_context.return_value = ["mock_context"]
mock_request = Mock()
mock_request.json.return_value = {"message": "Foo"}

mock_llm.stream.side_effect = openai.APIConnectionError(
message="Foo",
request=mock_request,
)

list(inference(latest_message, history))
mock_logger.error.assert_called_once_with("Failed to connect to backend API: %s", mock_llm.stream.side_effect)
mock_gr.Warning.assert_any_call("Failed to connect to backend API.")

@patch("app.llm")
@patch("app.build_chat_context")
@patch("app.gr")
def test_inference_InternalServerError(self, mock_gr, mock_build_chat_context, mock_llm):
mock_build_chat_context.return_value = ["mock_context"]
mock_request = Mock()
mock_request.json.return_value = {"message": "Foo"}

mock_llm.stream.side_effect = openai.InternalServerError(
message="Foo",
response=mock_request,
body=None
)

list(inference(latest_message, history))
mock_gr.Warning.assert_any_call("Internal server error encountered in backend API - see API logs for details.")

if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=2)
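A subtlety to keep in mind when extending these tests: stacked @patch decorators are applied bottom-up, so the mock arguments are passed to the test method in the reverse of the decorator order. In test_inference_success above, for example:

    @patch("app.settings")   # outermost decorator -> last mock argument
    @patch("app.llm")
    @patch("app.log")        # innermost decorator -> first mock argument
    def test_inference_success(self, mock_logger, mock_llm, mock_settings):
        ...
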
2 changes: 2 additions & 0 deletions web-apps/flux-image-gen/Dockerfile
@@ -7,6 +7,8 @@ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
ARG DIR=flux-image-gen

COPY $DIR/requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir --upgrade setuptools
RUN pip install --no-cache-dir -r requirements.txt

COPY purge-google-fonts.sh .
2 changes: 2 additions & 0 deletions web-apps/image-analysis/Dockerfile
@@ -4,6 +4,8 @@ ARG DIR=image-analysis

COPY $DIR/requirements.txt requirements.txt
COPY utils utils
RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir --upgrade setuptools
RUN pip install --no-cache-dir -r requirements.txt

COPY purge-google-fonts.sh purge-google-fonts.sh
2 changes: 1 addition & 1 deletion web-apps/test-images.sh
@@ -76,7 +76,7 @@ test() {
--name $1-test-suite \
-e GRADIO_URL=http://$1-app:7860 --entrypoint python \
$IMAGE \
test.py
test.py -v

log "Removing containers:"
docker rm -f ollama $1-app
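Small overlap worth knowing: unittest.main() also parses the command line, so the -v passed here and the verbosity=2 set in test.py both request verbose test output; either change on its own would have been enough.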