Install dependencies:

```sh
go mod tidy
```

Install mockgen:

```sh
go install go.uber.org/mock/mockgen@latest
go get go.uber.org/mock/mockgen/model
```

Run tests:

```sh
make test
```

Run on the local machine:

```sh
make server
```
This is the key component of our ML service. The model serving system has two layers:
- The serving platform, e.g., KServe, Replicate, RunPod, or a plain Kubernetes deployment.
- The serving agent (this repo), which runs on Kubernetes, e.g., GKE or EKS.

The serving agent offers sync and async prediction APIs, redirecting requests to the underlying ML platform (e.g., KServe, Replicate) and providing task queues for long-running predictions. The agent requires a Redis instance or Redis cluster for the task queue, plus a webhook service (implemented in this repo) for updating prediction statuses and results.
The serving agent provides both sync and async APIs. For the sync prediction API, when the agent receives a request, it does the following (see the sketch after this list):
- Check that the request format is valid.
- Create a new task record via the webhook service.
- Send the request to the underlying ML platform and wait for the prediction results.
- If the prediction succeeded, update the task record via the webhook service and return the results.
- If the prediction failed, update the task status to `failed`.
- If any webhook call failed, return an error.
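A minimal sketch of this sync flow is below. The `WebhookClient` and `Platform` interfaces and the function names are illustrative assumptions, not this repo's actual API:

```go
package server

import (
	"context"
	"fmt"
)

// PredictRequest mirrors the JSON body {"model_name": ..., "inputs": ...}.
type PredictRequest struct {
	ModelName string                 `json:"model_name"`
	Inputs    map[string]interface{} `json:"inputs"`
}

// WebhookClient is a hypothetical stand-in for the repo's webhook service client.
type WebhookClient interface {
	CreateTask(ctx context.Context, req *PredictRequest) (taskID string, err error)
	UpdateTask(ctx context.Context, taskID, status string, result []byte) error
}

// Platform is a hypothetical stand-in for the underlying ML platform client.
type Platform interface {
	Predict(ctx context.Context, req *PredictRequest) ([]byte, error)
}

func SyncPredict(ctx context.Context, wh WebhookClient, ml Platform, req *PredictRequest) ([]byte, error) {
	// 1. Check that the request format is valid.
	if req.ModelName == "" || req.Inputs == nil {
		return nil, fmt.Errorf("invalid request: model_name and inputs are required")
	}
	// 2. Create a new task record via the webhook service.
	taskID, err := wh.CreateTask(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("create task: %w", err)
	}
	// 3. Send the request to the ML platform and wait for the results.
	result, err := ml.Predict(ctx, req)
	if err != nil {
		// Prediction failed: mark the task record as failed.
		if werr := wh.UpdateTask(ctx, taskID, "failed", nil); werr != nil {
			return nil, fmt.Errorf("update task after failure: %w", werr)
		}
		return nil, fmt.Errorf("prediction: %w", err)
	}
	// Prediction succeeded: update the task record and return the results.
	if werr := wh.UpdateTask(ctx, taskID, "succeeded", result); werr != nil {
		return nil, fmt.Errorf("update task: %w", werr)
	}
	return result, nil
}
```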
For the async prediction API, we use the asynq library. The asynq client does the following (see the sketch after this list):
- Check that the request format is valid.
- Check whether the task queue is full. If the queue is full, return an error.
- Create a new task record via the webhook service.
- Submit the prediction task to the asynq task queue.
- Return the task ID if all of these steps succeeded.
- If any step failed, update the task record status to `failed` and return an error.
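The sketch below shows the queue-size check and enqueue steps using the real `github.com/hibiken/asynq` package; the task type name `prediction:run`, the `default` queue name, and the retry count are assumptions, and the webhook task-record call is omitted for brevity:

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"

	"github.com/hibiken/asynq"
)

const (
	taskTypePredict = "prediction:run" // assumed task type name
	maxQueueSize    = 10               // corresponds to MAX_QUEUE_SIZE
)

func enqueuePrediction(client *asynq.Client, inspector *asynq.Inspector, payload []byte) (string, error) {
	// Check whether the queue is full: scheduled + pending + retry tasks.
	info, err := inspector.GetQueueInfo("default")
	if err != nil {
		return "", fmt.Errorf("inspect queue: %w", err)
	}
	if info.Scheduled+info.Pending+info.Retry >= maxQueueSize {
		return "", fmt.Errorf("task queue is full")
	}
	// Submit the prediction task; asynq assigns and returns a task ID.
	task := asynq.NewTask(taskTypePredict, payload)
	taskInfo, err := client.Enqueue(task, asynq.MaxRetry(3))
	if err != nil {
		return "", fmt.Errorf("enqueue: %w", err)
	}
	return taskInfo.ID, nil
}

func main() {
	redisOpt := asynq.RedisClientOpt{Addr: "0.0.0.0:6379"} // REDIS_ADDRESS
	client := asynq.NewClient(redisOpt)
	defer client.Close()
	inspector := asynq.NewInspector(redisOpt)

	payload, _ := json.Marshal(map[string]any{"model_name": "model"})
	id, err := enqueuePrediction(client, inspector, payload)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("task ID:", id)
}
```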
The asynq server does the following (see the sketch after this list):
- Pull a task from the task queue and hand it to an available worker.
- The worker verifies the payload format.
- The worker changes the task status to `running`.
- The worker sends the request to the underlying ML platform and waits for the prediction results.
- If the prediction succeeded, update the task record via the webhook service.
- If the prediction failed, update the task status to `failed`.
- If any webhook call failed, raise an error and retry the task later.
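An illustrative sketch of the server side, again using the real `github.com/hibiken/asynq` package: returning a non-nil error from the handler makes asynq retry the task later, which is how webhook failures get retried, while wrapping `asynq.SkipRetry` fails a task permanently. The handler body below is a placeholder:

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/hibiken/asynq"
)

func handlePredict(ctx context.Context, t *asynq.Task) error {
	// Verify the payload format; a malformed payload will never succeed,
	// so skip retrying it.
	var req struct {
		ModelName string          `json:"model_name"`
		Inputs    json.RawMessage `json:"inputs"`
	}
	if err := json.Unmarshal(t.Payload(), &req); err != nil {
		return fmt.Errorf("invalid payload: %v: %w", err, asynq.SkipRetry)
	}
	// Set the task status to "running" via the webhook service, call the
	// ML platform, then record success or failure (omitted here). On a
	// webhook failure, return the error so asynq retries the task later.
	return nil
}

func main() {
	srv := asynq.NewServer(
		asynq.RedisClientOpt{Addr: "0.0.0.0:6379"}, // REDIS_ADDRESS
		asynq.Config{Concurrency: 8},               // WORKER_CONCURRENCY
	)
	mux := asynq.NewServeMux()
	mux.HandleFunc("prediction:run", handlePredict)
	if err := srv.Run(mux); err != nil {
		log.Fatal(err)
	}
}
```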
The asynq queue can use either a local in-memory Redis or a Redis cluster in the cloud, depending on whether high availability is required. When an agent is terminated by Kubernetes (during an upgrade or rescheduling), the agent service should be shut down in different ways according to the Redis type (see the sketch after this list):
- If we use a Redis cluster in the cloud, we just call `Shutdown` of `RedisTaskProcessor`. This method waits for the active running tasks to finish and shuts down the asynq server. If a running task doesn't finish within `TASK_TIMEOUT` seconds, it is put back into the queue.
- If we use a local Redis, we call `ShutdownDistributor`. `ShutdownDistributor` waits for `SHUTDOWN_DELAY` seconds first, then waits for the active running tasks to finish and sets the other tasks (scheduled, pending, or retry) in the queue to `failed`.

Because of this termination behavior, we need to set `terminationGracePeriodSeconds` to a proper value, i.e., greater than `SHUTDOWN_DELAY`.
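A minimal sketch of the SIGTERM handling, assuming stand-ins for the repo's `RedisTaskProcessor.Shutdown` and `ShutdownDistributor` described above; the `clusterMode` flag corresponds to `REDIS_CLUSTER_MODE`:

```go
package agent

import (
	"os"
	"os/signal"
	"syscall"
)

// taskProcessor is a hypothetical stand-in for RedisTaskProcessor.
type taskProcessor interface{ Shutdown() }

func waitAndShutdown(clusterMode bool, p taskProcessor, shutdownDistributor func()) {
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGTERM)
	<-sigs // Kubernetes sends SIGTERM before killing the pod.

	if clusterMode {
		// Cloud Redis cluster: wait for active tasks to finish and stop
		// the asynq server; tasks exceeding TASK_TIMEOUT are re-queued.
		p.Shutdown()
	} else {
		// Local Redis: ShutdownDistributor waits SHUTDOWN_DELAY seconds,
		// drains active tasks, and marks scheduled/pending/retry tasks
		// as failed.
		shutdownDistributor()
	}
}
```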
If the agent or the local Redis fails, some tasks may remain `pending` or `running` in the database. This can happen when task creation succeeds but the subsequent steps are never executed due to a node failure or other unknown errors. It happens rarely, but it still needs to be handled. To handle this issue, we do the following checks (see the sketch after this list):
- The service periodically (e.g., every 30 minutes) checks whether there are archived tasks in the queue. If some tasks are archived, their status is set to `failed`.
- Before starting the service, we check whether there are `pending` or `running` tasks recorded in the database by calling `InitCheck` in the `worker` package. This function handles the remaining `pending` or `running` tasks.
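The periodic archived-task check could look like the sketch below, using asynq's `Inspector`; the `default` queue name and `markTaskFailed` (a placeholder for the webhook status update) are assumptions:

```go
package worker

import (
	"log"
	"time"

	"github.com/hibiken/asynq"
)

func watchArchivedTasks(inspector *asynq.Inspector, markTaskFailed func(taskID string)) {
	ticker := time.NewTicker(30 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		// Archived tasks have exhausted their retries; mark them failed.
		tasks, err := inspector.ListArchivedTasks("default")
		if err != nil {
			log.Printf("list archived tasks: %v", err)
			continue
		}
		for _, t := range tasks {
			markTaskFailed(t.ID)
			// Remove the task from the archive so it isn't reprocessed.
			if err := inspector.DeleteTask("default", t.ID); err != nil {
				log.Printf("delete task %s: %v", t.ID, err)
			}
		}
	}
}
```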
The key APIs (see the example call after the table):
API | Description | Method | Input data (JSON format) |
---|---|---|---|
/v1/predict | The sync prediction API | POST | `{"model_name": "model", "inputs": {<MODEL_INPUTS>}}` |
/async/v1/predict | The async prediction API | POST | `{"model_name": "model", "inputs": {<MODEL_INPUTS>}}` |
/task/{ID} | Get the task information | GET | NA |
/cancel/{ID} | Cancel a pending task | POST | NA |
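For example, a call to the sync prediction API; the host/port and the model inputs are placeholders:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Placeholder model inputs; HTTP_SERVER_ADDRESS is assumed 0.0.0.0:8000.
	body := []byte(`{"model_name": "model", "inputs": {"prompt": "hello"}}`)
	resp, err := http.Post("http://0.0.0.0:8000/v1/predict", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(out))
}
```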
Here are the key parameters:
Parameter | Description | Sample value |
---|---|---|
HTTP_SERVER_ADDRESS | The TCP address for the server to listen on | 0.0.0.0:8000 |
REDIS_ADDRESS | The Redis server address for asynq | 0.0.0.0:6379 |
REDIS_CLUSTER_MODE | Whether the Redis is a cluster | False |
WORKER_CONCURRENCY | The number of asynq workers | 8, 64, or more |
MAX_QUEUE_SIZE | The maximum number of scheduled, pending, and retry tasks | 10 |
ML_PLATFORM | Which ML platform to use | kserve, k8s, replicate, runpod |
WEBHOOK_SERVER_ADDRESS | The serving webhook address | 0.0.0.0:12000 |
UPLOAD_WEBHOOK_ADDRESS | The webhook address for uploading images or files | 0.0.0.0:12000 |
SHUTDOWN_DELAY | How many seconds the server waits after receiving SIGTERM | 340 |
TASK_TIMEOUT | The timeout of a prediction task, in seconds | 320 |
The following are additional parameters that depend on which ML platform is used. For KServe:
Parameter | Description | Sample value |
---|---|---|
KSERVE_VERSION | The KServe version | 0.10.2 |
KSERVE_ADDRESS | The KServe address | 0.0.0.0:8080 |
KSERVE_CUSTOM_DOMAIN | The custom domain | example.com |
KSERVE_NAMESPACE | The namespace where the model is deployed | default |
KSERVE_REQUEST_TIMEOUT | The timeout for a prediction request | 180 |
For Replicate:
Parameter | Description | Sample value |
---|---|---|
REPLICATE_ADDRESS | The Replicate API address | https://api.replicate.com/v1/predictions |
REPLICATE_APIKEY | The Replicate API key | xxxxx |
REPLICATE_MODEL_ID | The model ID | xxxxx |
REPLICATE_REQUEST_TIMEOUT | The timeout for a prediction request | 180 |
For RunPod:
Parameter | Description | Sample value |
---|---|---|
RUNPOD_ADDRESS | The RunPod API address | https://api.runpod.ai/v2 |
RUNPOD_APIKEY | The RunPod API key | xxxxx |
RUNPOD_MODEL_ID | The model ID | xxxxx |
RUNPOD_REQUEST_TIMEOUT | The timeout for a prediction request | 180 |