Skip to content

Commit

Permalink
RAG index and query validation
Browse files Browse the repository at this point in the history
Signed-off-by: Bangqi Zhu <[email protected]>
  • Loading branch information
Bangqi Zhu committed Feb 10, 2025
1 parent f5fb284 commit 8beda3c
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 3 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/ragengine-e2e.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: ragengine-e2e-test

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

on:
pull_request:
paths-ignore: ['docs/**', '**.md', '**.mdx', '**.png', '**.jpg']

env:
GO_VERSION: "1.22"

permissions:
id-token: write # This is required for requesting the JWT
contents: read # This is required for actions/checkout

jobs:
run-e2e:
strategy:
fail-fast: false
matrix:
node-provisioner: [gpuprovisioner] # WIP: azkarpenter]
permissions:
contents: read
id-token: write
statuses: write
uses: ./.github/workflows/ragengine-e2e-workflow.yml
with:
git_sha: ${{ github.event.pull_request.head.sha }}
node_provisioner: ${{ matrix.node-provisioner }}
2 changes: 1 addition & 1 deletion pkg/ragengine/controllers/preset-rag.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func CreatePresetRAG(ctx context.Context, ragEngineObj *v1alpha1.RAGEngine, revi
}
commands := utils.ShellCmd("python3 main.py")
// TODO: provide this image
image := "mcr.microsoft.com/aks/kaito/kaito-rag-service:0.0.1"
image := "aimodelsregistrytest.azurecr.io/kaito-ragengine:0.0.1"

imagePullSecretRefs := []corev1.LocalObjectReference{}

Expand Down
148 changes: 146 additions & 2 deletions test/rage2e/rag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"math/rand"
"os"
"strings"
"time"

. "github.com/onsi/ginkgo/v2"
Expand Down Expand Up @@ -105,6 +106,9 @@ var _ = Describe("RAGEngine", func() {
validateInferenceandRAGResource(ragengineObj.ObjectMeta, int32(numOfReplica), false)
validateRAGEngineCondition(ragengineObj, string(kaitov1alpha1.RAGEngineConditionTypeSucceeded), "ragengine to be ready")

createIndexPod(ragengineObj)
createAndValidateQueryPod(ragengineObj)

})

It("should create RAG with localembedding and huggingface API successfully", func() {
Expand All @@ -119,6 +123,9 @@ var _ = Describe("RAGEngine", func() {
validateInferenceandRAGResource(ragengineObj.ObjectMeta, int32(numOfReplica), false)
validateRAGEngineCondition(ragengineObj, string(kaitov1alpha1.RAGEngineConditionTypeSucceeded), "ragengine to be ready")

createIndexPod(ragengineObj)
//TODO: add the createAndValidateQueryPod here in the next PR

})
})

Expand Down Expand Up @@ -210,7 +217,7 @@ func createLocalEmbeddingKaitoVLLMRAGEngine(baseURL string) *kaitov1alpha1.RAGEn
serviceURL := fmt.Sprintf("http://%s/v1/completions", baseURL)
By("Creating RAG with localembedding and kaito vllm inference", func() {
uniqueID := fmt.Sprint("rag-", rand.Intn(1000))
ragEngineObj = GenerateLocalEmbeddingRAGEngineManifest(uniqueID, namespaceName, "Standard_NC6s_v3", "BAAI/bge-small-en-v1.5",
ragEngineObj = GenerateLocalEmbeddingRAGEngineManifest(uniqueID, namespaceName, "Standard_NC24s_v3", "BAAI/bge-small-en-v1.5",
&metav1.LabelSelector{
MatchLabels: map[string]string{"apps": "phi-3"},
},
Expand All @@ -229,7 +236,7 @@ func createLocalEmbeddingHFURLRAGEngine() *kaitov1alpha1.RAGEngine {
hfURL := "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta/v1/completions"
By("Creating RAG with localembedding and huggingface API", func() {
uniqueID := fmt.Sprint("rag-", rand.Intn(1000))
ragEngineObj = GenerateLocalEmbeddingRAGEngineManifest(uniqueID, namespaceName, "Standard_NC6s_v3", "BAAI/bge-small-en-v1.5",
ragEngineObj = GenerateLocalEmbeddingRAGEngineManifest(uniqueID, namespaceName, "Standard_NC12s_v3", "BAAI/bge-small-en-v1.5",
&metav1.LabelSelector{
MatchLabels: map[string]string{"apps": "phi-3"},
},
Expand Down Expand Up @@ -440,3 +447,140 @@ func deleteWorkspace(workspaceObj *kaitov1alpha1.Workspace) error {

return nil
}

func createIndexPod(ragengineObj *kaitov1alpha1.RAGEngine) error {
By("Creating index pod", func() {
pod := GenerateIndexPodManifest(ragengineObj.Namespace, ragengineObj.Name)
Eventually(func() error {
return utils.TestingCluster.KubeClient.Create(ctx, pod, &client.CreateOptions{})
}, utils.PollTimeout, utils.PollInterval).
Should(Succeed(), "Failed to create index pod")
})
time.Sleep(60 * time.Second)

return nil
}

func createAndValidateQueryPod(ragengineObj *kaitov1alpha1.RAGEngine) error {
By("Creating query pod", func() {
pod := GenerateQueryPodManifest(ragengineObj.Namespace, ragengineObj.Name)
Eventually(func() error {
return utils.TestingCluster.KubeClient.Create(ctx, pod, &client.CreateOptions{})
}, utils.PollTimeout, utils.PollInterval).
Should(Succeed(), "Failed to create query pod")
})

By("Checking the query logs", func() {
Eventually(func() bool {
coreClient, err := utils.GetK8sClientset()
if err != nil {
GinkgoWriter.Printf("Failed to create core client: %v\n", err)
return false
}

logs, err := utils.GetPodLogs(coreClient, ragengineObj.Namespace, "querypod", "")
if err != nil {
GinkgoWriter.Printf("Failed to get logs from pod %s: %v\n", "querypod", err)
return false
}

searchQuerySuccess := "'text': '\\nKaito is an operator that automates the AI/ML model inference or tuning workload in a Kubernetes cluster.\\n"

GinkgoWriter.Printf("Expected (len=%d): %q\n", len(searchQuerySuccess), searchQuerySuccess)

if idx := strings.Index(logs, "'text'"); idx != -1 {
actualText := logs[idx:min(idx+200, len(logs))]
GinkgoWriter.Printf("Actual (len=%d): %q\n", len(actualText), actualText)

for i := 0; i < min(len(searchQuerySuccess), len(actualText)); i++ {
if searchQuerySuccess[i] != actualText[i] {
GinkgoWriter.Printf("First difference at position %d: expected=%q, got=%q\n",
i, string(searchQuerySuccess[i]), string(actualText[i]))
break
}
}
}
GinkgoWriter.Print(strings.Contains(logs, searchQuerySuccess))
return true
}, 2*time.Minute, utils.PollInterval).Should(BeTrue(), "Failed to wait for query logs to be ready")
})

return nil
}

func min(a, b int) int {
if a < b {
return a
}
return b
}

func GenerateIndexPodManifest(namespace, serviceName string) *v1.Pod {

curlCommand := `curl -X POST ` + serviceName + `:80/index \
-H "Content-Type: application/json" \
-d '{
"index_name": "kaito",
"documents": [
{
"text": "Kaito is an operator that automates the AI/ML model inference or tuning workload in a Kubernetes cluster",
"metadata": {"author": "kaito", "category": "kaito"}
}
]
}'`

indexPod := &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "indexpod",
Namespace: namespace,
},
Spec: v1.PodSpec{
RestartPolicy: v1.RestartPolicyNever,
Containers: []v1.Container{
{
Name: "curl",
Image: "curlimages/curl:latest",
Command: []string{"/bin/sh", "-c"},
Args: []string{curlCommand},
},
},
},
}

return indexPod
}

func GenerateQueryPodManifest(namespace, serviceName string) *v1.Pod { // TODO: add another model param for the remote inference service in the next PR

curlCommand := `curl -X POST ` + serviceName + `:80/query \
-H "Content-Type: application/json" \
-d '{
"index_name": "kaito",
"model": "phi-3-mini-128k-instruct",
"query": "what is kaito?",
"llm_params": {
"max_tokens": 50,
"temperature": 0
}
}'`

indexPod := &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "querypod",
Namespace: namespace,
},
Spec: v1.PodSpec{
RestartPolicy: v1.RestartPolicyNever,
Containers: []v1.Container{
{
Name: "curl",
Image: "curlimages/curl:latest",
Command: []string{"/bin/sh", "-c"},
Args: []string{curlCommand},
},
},
},
}

return indexPod
}

0 comments on commit 8beda3c

Please sign in to comment.