Production ready version
yangwenz committed Jun 2, 2024
1 parent 0b1ea90 commit 3a73586
Showing 35 changed files with 4,592 additions and 1 deletion.
15 changes: 15 additions & 0 deletions Dockerfile
@@ -0,0 +1,15 @@
# Build stage
FROM golang:1.21-alpine3.18 AS builder
WORKDIR /app
COPY . .
RUN go build -o main main.go

# Run stage
FROM alpine:3.18
WORKDIR /app

COPY --from=builder /app/main .
COPY app.env .

EXPOSE 8000
ENTRYPOINT [ "/app/main" ]
25 changes: 25 additions & 0 deletions Makefile
@@ -0,0 +1,25 @@

server:
go run main.go

test:
go test -v -cover -short ./...

mock:
mockgen -package mockplatform -destination platform/mock/platform.go github.com/HyperGAI/serving-agent/platform Platform
mockgen -package mockwk -destination worker/mock/distributor.go github.com/HyperGAI/serving-agent/worker TaskDistributor
mockgen -package mockplatform -destination platform/mock/webhook.go github.com/HyperGAI/serving-agent/platform Webhook,Fetcher

docker:
docker build --platform=linux/amd64 -t yangwenz/serving-agent:latest .
docker push yangwenz/serving-agent:latest

redis:
service redis stop
docker run --name redis -p 6379:6379 -d redis:7.2-alpine

dropredis:
docker stop redis
docker container rm redis

.PHONY: server test mock docker redis dropredis
162 changes: 161 additions & 1 deletion README.md
@@ -1,2 +1,162 @@
# serving-agent
The agent for model serving

## Setup

Install dependencies:

```shell
go mod tidy
```

Install mockgen:

```shell
go install go.uber.org/mock/mockgen@latest
go get go.uber.org/mock/mockgen/model@latest
```

Run tests:

```shell
make test
```

Run the server on a local machine:

```shell
make server
```

## Overview

This is the key component of our ML service. The system for model serving has two layers:

1. The serving platform, e.g., KServe, Replicate, RunPod or a plain Kubernetes deployment.
2. The serving agent (this repo), which runs on GKE or EKS.

The serving agent offers sync/async prediction APIs, forwards requests to the underlying
ML platform (e.g., KServe or Replicate), and provides task queues for long-running predictions.
The agent requires a Redis instance or Redis cluster for the task queue, plus the serving webhook implemented in this
[repo](https://github.com/HyperGAI/serving-webhook) for updating prediction statuses and results.

### Sync API

The serving agent provides both a sync and an async prediction API. When the agent receives a sync
prediction request, it performs the following steps (sketched in the code after this list):

1. Check that the request format is valid.
2. Create a new task record via the serving webhook.
3. Send the request to the underlying ML platform and wait for the prediction results.
4. If the prediction succeeds, update the task record via the serving webhook and return the results.
5. If the prediction fails, update the task status to `failed`.
6. If any webhook call fails, return the error.
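
A minimal sketch of this flow follows. The `Platform` and `Webhook` interfaces here are illustrative
assumptions; the real ones live in the `platform` package and may have different signatures.

```go
package sketch

import "errors"

// Illustrative stand-ins for the real platform interfaces.
type Platform interface {
	Predict(modelName string, inputs map[string]any) (map[string]any, error)
}

type Webhook interface {
	CreateTask(id string) error
	UpdateTask(id, status string, outputs map[string]any) error
}

// syncPredict mirrors steps 1-6 above.
func syncPredict(p Platform, w Webhook, id, model string, inputs map[string]any) (map[string]any, error) {
	if model == "" || inputs == nil { // step 1: validate the request
		return nil, errors.New("invalid request format")
	}
	if err := w.CreateTask(id); err != nil { // step 2: create the task record
		return nil, err // step 6: a webhook call failed
	}
	outputs, err := p.Predict(model, inputs) // step 3: call the ML platform
	if err != nil {
		_ = w.UpdateTask(id, "failed", nil) // step 5: mark the task failed
		return nil, err
	}
	if err := w.UpdateTask(id, "succeeded", outputs); err != nil { // step 4
		return nil, err // step 6: a webhook call failed
	}
	return outputs, nil // step 4: return the results
}
```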

### Async API

For the async prediction API, we use the [asynq](https://github.com/hibiken/asynq) library. The asynq client (sketched after this list):

1. Check that the request format is valid.
2. Check whether the task queue is full. If it is full, return an error.
3. Create a new task record via the serving webhook.
4. Submit the prediction task to the asynq task queue.
5. Return the task ID if all of these steps succeed.
6. If any step fails, update the task record status to `failed` and return the error.
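
A sketch of the client side using asynq's public API; the task type name `task:predict` and the
`default` queue name are assumptions for illustration, and the webhook calls from steps 3 and 6 are omitted:

```go
package sketch

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/hibiken/asynq"
)

// enqueuePrediction covers steps 2, 4 and 5 above.
func enqueuePrediction(redisAddr string, maxQueueSize int, payload map[string]any) (string, error) {
	opt := asynq.RedisClientOpt{Addr: redisAddr}

	// Step 2: reject the request when the queue is full.
	inspector := asynq.NewInspector(opt)
	q, err := inspector.GetQueueInfo("default")
	if err != nil {
		return "", err
	}
	if q.Scheduled+q.Pending+q.Retry >= maxQueueSize {
		return "", fmt.Errorf("task queue is full")
	}

	// Step 4: submit the prediction task to the queue.
	body, err := json.Marshal(payload)
	if err != nil {
		return "", err
	}
	client := asynq.NewClient(opt)
	defer client.Close()
	info, err := client.Enqueue(
		asynq.NewTask("task:predict", body),
		asynq.Timeout(320*time.Second), // aligned with TASK_TIMEOUT
	)
	if err != nil {
		return "", err
	}
	return info.ID, nil // step 5: return the task ID
}
```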

The asynq server (sketched after this list):

1. Pulls a task from the task queue and sends it to an available worker.
2. The worker verifies the payload format.
3. The worker changes the task status to `running`.
4. The worker sends the request to the underlying ML platform and waits for the prediction results.
5. If the prediction succeeds, update the task record via the serving webhook.
6. If the prediction fails, update the task status to `failed`.
7. If any webhook call fails, return an error so that asynq retries the task later.
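
A sketch of the server side; again, the task type name is an assumption, and the platform/webhook
calls are elided as comments:

```go
package sketch

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/hibiken/asynq"
)

func runWorkerServer(redisAddr string, concurrency int) {
	srv := asynq.NewServer(
		asynq.RedisClientOpt{Addr: redisAddr},
		asynq.Config{Concurrency: concurrency}, // WORKER_CONCURRENCY
	)
	mux := asynq.NewServeMux()
	mux.HandleFunc("task:predict", func(ctx context.Context, t *asynq.Task) error {
		// Step 2: verify the payload format; malformed payloads are never retried.
		var req map[string]any
		if err := json.Unmarshal(t.Payload(), &req); err != nil {
			return fmt.Errorf("bad payload: %v: %w", err, asynq.SkipRetry)
		}
		// Steps 3-6 happen here: set the task to `running`, call the ML
		// platform, and update the record via the serving webhook.
		// Step 7: returning a non-nil error makes asynq retry the task later.
		return nil
	})
	log.Fatal(srv.Run(mux))
}
```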

### Graceful Termination

The asynq queue can use either a local in-memory Redis or a Redis cluster on GCP, depending on
whether high availability is required. When an agent is terminated by k8s (e.g., during an upgrade or a rescheduling),
the agent service must be shut down differently depending on the Redis type:

1. If we use a Redis cluster on GCP, we only need to call `Shutdown` of `RedisTaskProcessor`. This method
waits for the actively running tasks to finish and then shuts down the asynq server. If a running task
doesn't finish within `TASK_TIMEOUT` seconds, it is put back into the queue.
2. If we use a local Redis, we call `ShutdownDistributor` instead.
`ShutdownDistributor` waits for `SHUTDOWN_DELAY` seconds first, then waits for the actively running tasks
to finish and sets the other tasks (scheduled, pending or retry) in the queue to `failed`.

Because of this termination behavior, `terminationGracePeriodSeconds` must be set to a proper value, i.e.,
greater than `SHUTDOWN_DELAY`. The sketch below illustrates the flow.
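
A minimal sketch of this shutdown logic, assuming standard signal handling; the two callbacks stand in
for `RedisTaskProcessor.Shutdown` and `ShutdownDistributor`, whose real signatures may differ:

```go
package sketch

import (
	"os"
	"os/signal"
	"syscall"
)

// waitForTermination blocks until k8s sends SIGTERM, then shuts the agent
// down according to the Redis type, as described above.
func waitForTermination(redisClusterMode bool, shutdownProcessor, shutdownDistributor func()) {
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGTERM)
	<-quit // k8s then waits up to terminationGracePeriodSeconds

	if redisClusterMode {
		// Redis cluster on GCP: unfinished tasks are put back into the queue.
		shutdownProcessor()
	} else {
		// Local Redis: wait SHUTDOWN_DELAY seconds, drain the active tasks,
		// and set the remaining scheduled/pending/retry tasks to `failed`.
		shutdownDistributor()
	}
}
```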

### Node Failure and Unknown Errors

If the agent or the local Redis fails, some tasks may remain `pending` or `running` in the database.
This can happen when task creation succeeds but the subsequent steps are never executed
due to a node failure or other unknown errors. It happens rarely, but still needs to be
handled. To do so, we run the following checks:
1. The service periodically (every 10 minutes) checks whether there are archived tasks in the queue.
Any archived task has its status set to `failed` (sketched below).
2. Before starting the service, we check whether there are `pending` or `running` tasks recorded in the database
by calling `InitCheck` in the `worker` package, which handles any remaining `pending` or `running` tasks.
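
A sketch of check 1 using asynq's `Inspector`; the `markFailed` callback stands in for the
serving-webhook update, and the queue name is an assumption:

```go
package sketch

import (
	"time"

	"github.com/hibiken/asynq"
)

// checkArchivedTasks marks every archived task (one that has exhausted its
// retries) as failed, once every 10 minutes.
func checkArchivedTasks(redisAddr string, markFailed func(taskID string)) {
	inspector := asynq.NewInspector(asynq.RedisClientOpt{Addr: redisAddr})
	for range time.Tick(10 * time.Minute) {
		tasks, err := inspector.ListArchivedTasks("default")
		if err != nil {
			continue // transient redis error; try again on the next tick
		}
		for _, t := range tasks {
			markFailed(t.ID)
			_ = inspector.DeleteTask("default", t.ID) // remove it from the queue
		}
	}
}
```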

## API Definition

The key APIs:

| API | Description | Method | Input data (JSON format) |
|:-----------------:|:------------------------:|:------:|:---------------------------------------------------:|
| /v1/predict | The sync prediction API | POST | {"model_name": "model", "inputs": {<MODEL_INPUTS>}} |
| /async/v1/predict | The async prediction API | POST | {"model_name": "model", "inputs": {<MODEL_INPUTS>}} |
| /task/{ID} | Get the task information | GET | NA |
| /cancel/{ID} | Cancel a pending task | POST | NA |
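
For example, a minimal Go client for the async API (the `prompt` field is a hypothetical model input):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	body := []byte(`{"model_name": "model", "inputs": {"prompt": "a cat"}}`)
	resp, err := http.Post("http://0.0.0.0:8000/async/v1/predict", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(out)) // the response carries the task ID
}
```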

## Parameter Settings

Here are the key parameters:

| Parameter | Description | Sample value |
|:----------------------:|:------------------------------------------------------------:|:------------------------------:|
| HTTP_SERVER_ADDRESS | The TCP address for the server to listen on | 0.0.0.0:8000 |
| REDIS_ADDRESS | The Redis server address for asynq | 0.0.0.0:6379 |
| REDIS_CLUSTER_MODE | Whether the Redis is a cluster | False |
| WORKER_CONCURRENCY | The number of asynq workers | 8, 64 or more |
| MAX_QUEUE_SIZE | The maximum number of scheduled, pending and retry tasks | 10 |
| ML_PLATFORM | Which ML platform to use | kserve, k8s, replicate, runpod |
| WEBHOOK_SERVER_ADDRESS | The serving webhook address | 0.0.0.0:12000 |
| UPLOAD_WEBHOOK_ADDRESS | The webhook address for uploading images or files | 0.0.0.0:12000 |
| SHUTDOWN_DELAY | How many seconds the server waits after receiving SIGTERM | 340 |
| TASK_TIMEOUT | The timeout of a prediction task, in seconds | 320 |
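
As an illustration, a hypothetical `app.env` (the file the Dockerfile copies into the image)
combining the sample values above; the exact keys and format depend on the config loader:

```shell
HTTP_SERVER_ADDRESS=0.0.0.0:8000
REDIS_ADDRESS=0.0.0.0:6379
REDIS_CLUSTER_MODE=false
WORKER_CONCURRENCY=8
MAX_QUEUE_SIZE=10
ML_PLATFORM=kserve
WEBHOOK_SERVER_ADDRESS=0.0.0.0:12000
UPLOAD_WEBHOOK_ADDRESS=0.0.0.0:12000
SHUTDOWN_DELAY=340
TASK_TIMEOUT=320
```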

The remaining parameters depend on which ML platform is used. For KServe:

| Parameter | Description | Sample value |
|:----------------------:|:------------------------------------------------:|:------------:|
| KSERVE_ADDRESS | The KServe address | 0.0.0.0:8080 |
| KSERVE_CUSTOM_DOMAIN | The custom domain | example.com |
| KSERVE_NAMESPACE | The namespace where the model is deployed | default |
| KSERVE_REQUEST_TIMEOUT | The timeout for a prediction request, in seconds | 180 |

For Replicate:

| Parameter | Description | Sample value |
|:-------------------------:|:------------------------------------------------:|:-----------------------------------------:|
| REPLICATE_ADDRESS | The Replicate API address | https://api.replicate.com/v1/predictions |
| REPLICATE_APIKEY | The Replicate API key | xxxxx |
| REPLICATE_MODEL_ID | The model ID | xxxxx |
| REPLICATE_REQUEST_TIMEOUT | The timeout for a prediction request, in seconds | 180 |

For RunPod:

| Parameter | Description | Sample value |
|:----------------------:|:------------------------------------------------:|:------------------------:|
| RUNPOD_ADDRESS | The RunPod API address | https://api.runpod.ai/v2 |
| RUNPOD_APIKEY | The RunPod API key | xxxxx |
| RUNPOD_MODEL_ID | The model ID | xxxxx |
| RUNPOD_REQUEST_TIMEOUT | The timeout for a prediction request, in seconds | 180 |

For a plain K8s deployment:

| Parameter | Description | Sample value |
|:-------------------------:|:------------------------------------------------:|:------------:|
| K8SPLUGIN_ADDRESS | The k8s serving plugin address | 0.0.0.0:8002 |
| K8SPLUGIN_REQUEST_TIMEOUT | The timeout for a prediction request, in seconds | 180 |
28 changes: 28 additions & 0 deletions api/main_test.go
@@ -0,0 +1,28 @@
package api

import (
"github.com/HyperGAI/serving-agent/platform"
"github.com/HyperGAI/serving-agent/utils"
"github.com/HyperGAI/serving-agent/worker"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"os"
"testing"
)

func newTestServer(
t *testing.T,
platform platform.Platform,
distributor worker.TaskDistributor,
webhook platform.Webhook,
) *Server {
config := utils.Config{MaxQueueSize: 300}
server, err := NewServer(config, platform, distributor, webhook)
require.NoError(t, err)
return server
}

func TestMain(m *testing.M) {
gin.SetMode(gin.TestMode)
os.Exit(m.Run())
}
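
// Illustrative sketch (not part of this commit): a test can build the server
// with the mocks generated by `make mock`; the NewMock* constructors follow
// mockgen's standard naming convention.
//
//	func TestNewServerWithMocks(t *testing.T) {
//		ctrl := gomock.NewController(t)
//		defer ctrl.Finish()
//		server := newTestServer(t,
//			mockplatform.NewMockPlatform(ctrl),
//			mockwk.NewMockTaskDistributor(ctrl),
//			mockplatform.NewMockWebhook(ctrl))
//		require.NotNil(t, server)
//	}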
57 changes: 57 additions & 0 deletions api/metrics.go
@@ -0,0 +1,57 @@
package api

import (
"github.com/gin-gonic/gin"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"strconv"
)

var totalRequests = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Number of requests",
},
[]string{"path"},
)

var responseStatus = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "response_status",
Help: "Status of HTTP response",
},
[]string{"path", "status"},
)

var httpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
Name: "http_response_time_seconds",
Help: "Duration of HTTP requests",
}, []string{"path"})

var queueSizeGauge = promauto.NewGauge(prometheus.GaugeOpts{
Name: "task_queue_size",
Help: "The size of the task queue",
})

var queueSizeRatioGauge = promauto.NewGauge(prometheus.GaugeOpts{
Name: "task_queue_size_ratio",
Help: "The ratio of the queue size",
})

var taskStatusSetToFailedGauge = promauto.NewGauge(prometheus.GaugeOpts{
Name: "num_tasks_set_to_failed",
Help: "The number of the tasks set to failed by CheckTaskStatus",
})

func prometheusMiddleware() gin.HandlerFunc {
return func(ctx *gin.Context) {
path := ctx.FullPath()
timer := prometheus.NewTimer(httpDuration.WithLabelValues(path))
// Call the next middleware or endpoint handler
ctx.Next()
// Update metrics
totalRequests.WithLabelValues(path).Inc()
responseStatus.WithLabelValues(path, strconv.Itoa(ctx.Writer.Status())).Inc()
timer.ObserveDuration()
}
}
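
// Usage sketch (an illustrative addition, not part of this file): the
// middleware and a /metrics endpoint can be wired into the gin router with the
// standard promhttp handler (github.com/prometheus/client_golang/prometheus/promhttp).
//
//	router := gin.New()
//	router.Use(prometheusMiddleware())
//	router.GET("/metrics", gin.WrapH(promhttp.Handler()))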