diff --git a/common/util.go b/common/util.go
index 19fb0b6a9..159e30574 100644
--- a/common/util.go
+++ b/common/util.go
@@ -462,6 +462,24 @@ func ParseEthAddr(strJsonKey string) (string, error) {
 	return "", errors.New("Error parsing address from keyfile")
 }
 
+func GetInputVideoInfo(video types.File) (ffmpeg.MediaFormatInfo, error) {
+	bytearr, err := video.Bytes()
+	if err != nil {
+		return ffmpeg.MediaFormatInfo{}, err
+	}
+	_, mediaFormat, err := ffmpeg.GetCodecInfoBytes(bytearr)
+	if err != nil {
+		return ffmpeg.MediaFormatInfo{}, errors.New("Error getting codec info")
+	}
+
+	duration := int64(mediaFormat.DurSecs)
+	if duration <= 0 {
+		return ffmpeg.MediaFormatInfo{}, errors.New("video duration calculation failed")
+	}
+
+	return mediaFormat, nil
+}
+
 // CalculateAudioDuration calculates audio file duration using the lpms/ffmpeg package.
 func CalculateAudioDuration(audio types.File) (int64, error) {
 	read, err := audio.Reader()
diff --git a/core/ai.go b/core/ai.go
index a9eeae9f7..939a2d44e 100644
--- a/core/ai.go
+++ b/core/ai.go
@@ -28,6 +28,7 @@ type AI interface {
 	ImageToText(context.Context, worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error)
 	TextToSpeech(context.Context, worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error)
 	LiveVideoToVideo(context.Context, worker.GenLiveVideoToVideoJSONRequestBody) (*worker.LiveVideoToVideoResponse, error)
+	ObjectDetection(context.Context, worker.GenObjectDetectionMultipartRequestBody) (*worker.ObjectDetectionResponse, error)
 	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
 	Stop(context.Context) error
 	HasCapacity(string, string) bool
diff --git a/core/ai_test.go b/core/ai_test.go
index 3e4ab8207..6e9c07f78 100644
--- a/core/ai_test.go
+++ b/core/ai_test.go
@@ -667,6 +667,15 @@ func (a *stubAIWorker) LiveVideoToVideo(ctx context.Context, req worker.GenLiveV
 	return &worker.LiveVideoToVideoResponse{}, nil
 }
 
+func (a *stubAIWorker) ObjectDetection(ctx context.Context, req worker.GenObjectDetectionMultipartRequestBody) (*worker.ObjectDetectionResponse, error) {
+	return &worker.ObjectDetectionResponse{
+		Video:            worker.Media{Url: "http://example.com/frames1.mp4"},
+		ConfidenceScores: "confidence_scores",
+		Labels:           "labels",
+		DetectionBoxes:   "detection_boxes",
+		DetectionPts:     "detection_pts"}, nil
+}
+
 func (a *stubAIWorker) Warm(ctx context.Context, arg1, arg2 string, endpoint worker.RunnerEndpoint, flags worker.OptimizationFlags) error {
 	return nil
 }
diff --git a/core/ai_worker.go b/core/ai_worker.go
index 235338fca..241867d50 100644
--- a/core/ai_worker.go
+++ b/core/ai_worker.go
@@ -465,6 +465,21 @@ func (n *LivepeerNode) saveLocalAIWorkerResults(ctx context.Context, results int
 		}
 		resp.Audio.Url = osUrl
 
+		results = resp
+	case worker.ObjectDetectionResponse:
+		if resp.Video.Url != "" {
+			err := worker.ReadVideoB64DataUrl(resp.Video.Url, &buf)
+			if err != nil {
+				return nil, err
+			}
+
+			osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewBuffer(buf.Bytes()), nil, 0)
+			if err != nil {
+				return nil, err
+			}
+			resp.Video.Url = osUrl
+		}
+
 		results = resp
 	}
 
@@ -510,6 +525,19 @@ func (n *LivepeerNode) saveRemoteAIWorkerResults(ctx context.Context, results *R
 		delete(results.Files, fileName)
 
 		results.Results = resp
+	case worker.ObjectDetectionResponse:
+		if resp.Video.Url != "" {
+			fileName := resp.Video.Url
+			osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewReader(results.Files[fileName]), nil, 0)
+			if err != nil {
+				return nil, err
+			}
+
+			resp.Video.Url = osUrl
+			delete(results.Files, fileName)
+		}
+
+		results.Results = resp
 	}
 
 	// no file response to save, response is text
@@ -884,6 +912,50 @@ func (orch *orchestrator) TextToSpeech(ctx context.Context, requestID string, re
 	return res.Results, nil
 }
 
+func (orch *orchestrator) ObjectDetection(ctx context.Context, requestID string, req worker.GenObjectDetectionMultipartRequestBody) (interface{}, error) {
+	// local AIWorker processes job if combined orchestrator/ai worker
+	if orch.node.AIWorker != nil {
+		workerResp, err := orch.node.ObjectDetection(ctx, req)
+		if err == nil {
+			return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "video/mp4")
+		} else {
+			clog.Errorf(ctx, "Error processing with local ai worker err=%q", err)
+			if monitor.Enabled {
+				monitor.AIResultSaveError(ctx, "object-detection", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
+			}
+			return nil, err
+		}
+	}
+
+	// remote ai worker processes job
+	videoBytes, err := req.Video.Bytes()
+	if err != nil {
+		return nil, err
+	}
+
+	inputUrl, err := orch.SaveAIRequestInput(ctx, requestID, videoBytes)
+	if err != nil {
+		return nil, err
+	}
+	req.Video.InitFromBytes(nil, "")
+
+	res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "object-detection", *req.ModelId, inputUrl, AIJobRequestData{Request: req, InputUrl: inputUrl})
+	if err != nil {
+		return nil, err
+	}
+
+	res, err = orch.node.saveRemoteAIWorkerResults(ctx, res, requestID)
+	if err != nil {
+		clog.Errorf(ctx, "Error saving remote ai result err=%q", err)
+		if monitor.Enabled {
+			monitor.AIResultSaveError(ctx, "object-detection", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
+		}
+		return nil, err
+	}
+
+	return res.Results, nil
+}
+
 // only used for sending work to remote AI worker
 func (orch *orchestrator) SaveAIRequestInput(ctx context.Context, requestID string, fileData []byte) (string, error) {
 	node := orch.node
@@ -1062,7 +1134,11 @@ func (n *LivepeerNode) LiveVideoToVideo(ctx context.Context, req worker.GenLiveV
 	return n.AIWorker.LiveVideoToVideo(ctx, req)
 }
 
-// transcodeFrames converts a series of image URLs into a video segment for the image-to-video pipeline.
+func (n *LivepeerNode) ObjectDetection(ctx context.Context, req worker.GenObjectDetectionMultipartRequestBody) (*worker.ObjectDetectionResponse, error) {
+	return n.AIWorker.ObjectDetection(ctx, req)
+}
+
+// transcodeFrames converts a series of image URLs into a video segment for the image-to-video and object-detection pipelines.
 func (n *LivepeerNode) transcodeFrames(ctx context.Context, sessionID string, urls []string, inProfile ffmpeg.VideoProfile, outProfile ffmpeg.VideoProfile) *TranscodeResult {
 	ctx = clog.AddOrchSessionID(ctx, sessionID)
 
diff --git a/core/capabilities.go b/core/capabilities.go
index d2425fa98..a2199ce85 100644
--- a/core/capabilities.go
+++ b/core/capabilities.go
@@ -83,6 +83,7 @@ const (
 	Capability_ImageToText      Capability = 34
 	Capability_LiveVideoToVideo Capability = 35
 	Capability_TextToSpeech     Capability = 36
+	Capability_ObjectDetection  Capability = 37
 )
 
 var CapabilityNameLookup = map[Capability]string{
@@ -124,6 +125,7 @@ var CapabilityNameLookup = map[Capability]string{
 	Capability_ImageToText:      "Image to text",
 	Capability_LiveVideoToVideo: "Live video to video",
 	Capability_TextToSpeech:     "Text to speech",
+	Capability_ObjectDetection:  "Object detection",
 }
 
 var CapabilityTestLookup = map[Capability]CapabilityTest{
@@ -217,6 +219,7 @@ func OptionalCapabilities() []Capability {
 		Capability_SegmentAnything2,
 		Capability_ImageToText,
 		Capability_TextToSpeech,
+		Capability_ObjectDetection,
 	}
 }
 
diff --git a/go.mod b/go.mod
index 2b7806e0c..d08c52eaf 100644
--- a/go.mod
+++ b/go.mod
@@ -257,3 +257,5 @@ require (
 	lukechampine.com/blake3 v1.2.1 // indirect
 	rsc.io/tmplfunc v0.0.3 // indirect
 )
+
+replace github.com/livepeer/ai-worker => github.com/RUFFY-369/ai-worker v0.8.1-0.20241102154421-60e5d350c2df
diff --git a/go.sum b/go.sum
index 14a5b8a05..741547eed 100644
--- a/go.sum
+++ b/go.sum
@@ -63,6 +63,8 @@ github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0
 github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
 github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
 github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
+github.com/RUFFY-369/ai-worker v0.8.1-0.20241102154421-60e5d350c2df h1:UL+t2GVDDk20eypLFcbbqiO95OmlyzjARCkBUbvVNmc=
+github.com/RUFFY-369/ai-worker v0.8.1-0.20241102154421-60e5d350c2df/go.mod h1:GjQuPmz69UO53WVtqzB9Ygok5MmKCGNuobbfMXH7zgw=
 github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
 github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0=
 github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA=
diff --git a/server/ai_http.go b/server/ai_http.go
index f738f3df0..812e8e64f 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -71,6 +71,7 @@ func startAIServer(lp *lphttp) error {
 	lp.transRPC.Handle("/image-to-text", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToTextMultipartRequestBody])))
 	lp.transRPC.Handle("/text-to-speech", oapiReqValidator(aiHttpHandle(lp, jsonDecoder[worker.GenTextToSpeechJSONRequestBody])))
 	lp.transRPC.Handle("/live-video-to-video", oapiReqValidator(lp.StartLiveVideoToVideo()))
+	lp.transRPC.Handle("/object-detection", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenObjectDetectionMultipartRequestBody])))
 	// Additionally, there is the '/aiResults' endpoint registered in server/rpc.go
 	return nil
 }
@@ -470,6 +471,21 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 		// TTS pricing is typically in characters, including punctuation.
 		words := utf8.RuneCountInString(*v.Text)
 		outPixels = int64(1000 * words)
+	case worker.GenObjectDetectionMultipartRequestBody:
+		pipeline = "object-detection"
+		cap = core.Capability_ObjectDetection
+		modelID = *v.ModelId
+		mediaFormat, err := common.GetInputVideoInfo(v.Video)
+		if err != nil {
+			respondWithError(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+
+		submitFn = func(ctx context.Context) (interface{}, error) {
+			return orch.ObjectDetection(ctx, requestID, v)
+		}
+		// Calculate the output pixels using the input video's format info
+		outPixels = int64(mediaFormat.Width) * int64(mediaFormat.Height) * int64(mediaFormat.FPS) * mediaFormat.DurSecs
 	default:
 		respondWithError(w, "Unknown request type", http.StatusBadRequest)
 		return
@@ -575,6 +591,8 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 		latencyScore = CalculateImageToTextLatencyScore(took, outPixels)
 	case worker.GenTextToSpeechJSONRequestBody:
 		latencyScore = CalculateTextToSpeechLatencyScore(took, outPixels)
+	case worker.GenObjectDetectionMultipartRequestBody:
+		latencyScore = CalculateObjectDetectionLatencyScore(took, outPixels)
 	}
 
 	var pricePerAIUnit float64
@@ -786,6 +804,16 @@ func parseMultiPartResult(body io.Reader, boundary string, pipeline string) core
 			}
 		case "text-to-speech":
 			var parsedResp worker.AudioResponse
+			err := json.Unmarshal(body, &parsedResp)
+			if err != nil {
+				glog.Error("Error getting results json:", err)
+				wkrResult.Err = err
+				break
+			}
+			results = parsedResp
+		case "object-detection":
+			var parsedResp worker.ObjectDetectionResponse
+
 			err := json.Unmarshal(body, &parsedResp)
 			if err != nil {
 				glog.Error("Error getting results json:", err)
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index 80f92948b..877ad35d9 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -81,6 +81,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
 	ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(aiMediaServerHandle(ls, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody], processSegmentAnything2)))
 	ls.HTTPMux.Handle("/image-to-text", oapiReqValidator(aiMediaServerHandle(ls, multipartDecoder[worker.GenImageToTextMultipartRequestBody], processImageToText)))
 	ls.HTTPMux.Handle("/text-to-speech", oapiReqValidator(aiMediaServerHandle(ls, jsonDecoder[worker.GenTextToSpeechJSONRequestBody], processTextToSpeech)))
+	ls.HTTPMux.Handle("/object-detection", oapiReqValidator(aiMediaServerHandle(ls, multipartDecoder[worker.GenObjectDetectionMultipartRequestBody], processObjectDetection)))
 
 	// This is called by the media server when the stream is ready
 	ls.HTTPMux.Handle("/live/video-to-video/{stream}/start", ls.StartLiveVideo())
diff --git a/server/ai_process.go b/server/ai_process.go
index ea15bd43e..8465e41b9 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -38,6 +38,7 @@ const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large"
 const defaultImageToTextModelID = "Salesforce/blip-image-captioning-large"
 const defaultLiveVideoToVideoModelID = "noop"
 const defaultTextToSpeechModelID = "parler-tts/parler-tts-large-v1"
+const defaultObjectDetectionModelID = "PekingU/rtdetr_r50vd"
 
 var errWrongFormat = fmt.Errorf("result not in correct format")
 
@@ -1348,6 +1349,145 @@ func processImageToText(ctx context.Context, params aiRequestParams, req worker.
 	return txtResp, nil
 }
 
+func CalculateObjectDetectionLatencyScore(took time.Duration, outPixels int64) float64 {
+	if outPixels <= 0 {
+		return 0
+	}
+
+	return took.Seconds() / float64(outPixels)
+}
+
+func submitObjectDetection(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenObjectDetectionMultipartRequestBody) (*worker.ObjectDetectionResponse, error) {
+	var buf bytes.Buffer
+	mw, err := worker.NewObjectDetectionMultipartWriter(&buf, req)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "object-detection", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "object-detection", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	mediaFormat, err := common.GetInputVideoInfo(req.Video)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "object-detection", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	// Calculate the output pixels using the input video's format info
+	outPixels := int64(mediaFormat.Width) * int64(mediaFormat.Height) * int64(mediaFormat.FPS) * mediaFormat.DurSecs
+	setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, outPixels)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "object-detection", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
+
+	start := time.Now()
+	resp, err := client.GenObjectDetectionWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
+	took := time.Since(start)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "object-detection", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	data, err := io.ReadAll(resp.Body)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "object-detection", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	if resp.StatusCode != 200 {
+		return nil, errors.New(string(data))
+	}
+
+	// We treat a response as "receiving change" where the change is the difference between the credit and debit for the update
+	if balUpdate != nil {
+		balUpdate.Status = ReceivedChange
+	}
+
+	var res worker.ObjectDetectionResponse
+	if err := json.Unmarshal(data, &res); err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "object-detection", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	// TODO: Refine this rough estimate in future iterations
+	sess.LatencyScore = CalculateObjectDetectionLatencyScore(took, outPixels)
+
+	if monitor.Enabled {
+		var pricePerAIUnit float64
+		if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
+			pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
+		}
+
+		monitor.AIRequestFinished(ctx, "object-detection", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
+	}
+
+	return &res, nil
+}
+
+func processObjectDetection(ctx context.Context, params aiRequestParams, req worker.GenObjectDetectionMultipartRequestBody) (*worker.ObjectDetectionResponse, error) {
+	resp, err := processAIRequest(ctx, params, req)
+	if err != nil {
+		return nil, err
+	}
+
+	detectionResp, ok := resp.(*worker.ObjectDetectionResponse)
+	if !ok {
+		return nil, errWrongFormat
+	}
+
+	if detectionResp.Video.Url != "" {
+		var result []byte
+		var data bytes.Buffer
+		var name string
+		writer := bufio.NewWriter(&data)
+		err = worker.ReadVideoB64DataUrl(detectionResp.Video.Url, writer)
+		if err == nil {
+			// orchestrator sent base64 encoded result in .Url
+			name = string(core.RandomManifestID()) + ".mp4"
+			writer.Flush()
+			result = data.Bytes()
+		} else {
+			// orchestrator sent download url, get the data
+
+			name = filepath.Base(detectionResp.Video.Url)
+			result, err = core.DownloadData(ctx, detectionResp.Video.Url)
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		newUrl, err := params.os.SaveData(ctx, name, bytes.NewReader(result), nil, 0)
+		if err != nil {
+			return nil, fmt.Errorf("error saving video to objectStore: %w", err)
+		}
+
+		detectionResp.Video.Url = newUrl
+	}
+
+	return detectionResp, nil
+}
+
 func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) {
 	var cap core.Capability
 	var modelID string
@@ -1451,6 +1591,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitLiveVideoToVideo(ctx, params, sess, v)
 		}
+	case worker.GenObjectDetectionMultipartRequestBody:
+		cap = core.Capability_ObjectDetection
+		modelID = defaultObjectDetectionModelID
+		if v.ModelId != nil {
+			modelID = *v.ModelId
+		}
+		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
+			return submitObjectDetection(ctx, params, sess, v)
+		}
 	default:
 		return nil, fmt.Errorf("unsupported request type %T", req)
 	}
diff --git a/server/ai_worker.go b/server/ai_worker.go
index 34dc722bf..401285cc2 100644
--- a/server/ai_worker.go
+++ b/server/ai_worker.go
@@ -314,6 +314,23 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify
 		return n.TextToSpeech(ctx, req)
 		}
 		reqOk = true
+	case "object-detection":
+		var req worker.GenObjectDetectionMultipartRequestBody
+		err = json.Unmarshal(reqData.Request, &req)
+		if err != nil || req.ModelId == nil {
+			break
+		}
+		input, err = core.DownloadData(ctx, reqData.InputUrl)
+		if err != nil {
+			break
+		}
+		modelID = *req.ModelId
+		resultType = "video/mp4"
+		req.Video.InitFromBytes(input, "video")
+		processFn = func(ctx context.Context) (interface{}, error) {
+			return n.ObjectDetection(ctx, req)
+		}
+		reqOk = true
 	default:
 		err = errors.New("AI request pipeline type not supported")
 	}
@@ -452,6 +469,34 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify
 		}
 		io.Copy(fw, &resBuf)
 		resBuf.Reset()
+	case *worker.ObjectDetectionResponse:
+		// annotated video is optional
+		if wkrResp.Video.Url != "" {
+			err := worker.ReadVideoB64DataUrl(wkrResp.Video.Url, &resBuf)
+
+			if err != nil {
+				clog.Errorf(ctx, "AI Worker failed to save video from data url err=%q", err)
+				sendAIResult(ctx, n, orchAddr, notify.AIJobData.Pipeline, modelID, httpc, contentType, &body, err)
+				return
+			}
+			length = resBuf.Len()
+			wkrResp.Video.Url = fmt.Sprintf("%v.mp4", core.RandomManifestID()) // update json response to track filename attached
+			// create the part
+			w.SetBoundary(boundary)
+			hdrs := textproto.MIMEHeader{
+				"Content-Type":        {resultType},
+				"Content-Length":      {strconv.Itoa(length)},
+				"Content-Disposition": {"attachment; filename=" + wkrResp.Video.Url},
+			}
+			fw, err := w.CreatePart(hdrs)
+			if err != nil {
+				clog.Errorf(ctx, "Could not create multipart part err=%q", err)
+				sendAIResult(ctx, n, orchAddr, notify.AIJobData.Pipeline, modelID, httpc, contentType, nil, err)
+				return
+			}
+			io.Copy(fw, &resBuf)
+			resBuf.Reset()
+		}
 	}
 
 	// add the json to the response
diff --git a/server/ai_worker_test.go b/server/ai_worker_test.go
index ab31a3e71..e681ede7b 100644
--- a/server/ai_worker_test.go
+++ b/server/ai_worker_test.go
@@ -140,6 +140,17 @@ func TestRunAIJob(t *testing.T) {
 			}
 			w.Write(imgData)
 			return
+		} else if r.URL.Path == "/video.mp4" {
+			data, err := os.ReadFile("../test/ai/video")
+			if err != nil {
+				t.Fatalf("failed to read test video: %v", err)
+			}
+			vidData, err := base64.StdEncoding.DecodeString(string(data))
+			if err != nil {
+				t.Fatalf("failed to decode base64 test video: %v", err)
+			}
+			w.Write(vidData)
+			return
 		}
 	}))
 	defer ts.Close()
@@ -218,16 +229,23 @@ func TestRunAIJob(t *testing.T) {
 			expectedErr:     "",
 			expectedOutputs: 1,
 		},
+		{
+			name:            "ObjectDetection_Success",
+			notify:          createAIJob(10, "object-detection", modelId, parsedURL.String()+"/video.mp4"),
+			pipeline:        "object-detection",
+			expectedErr:     "",
+			expectedOutputs: 1,
+		},
 		{
 			name:            "UnsupportedPipeline",
-			notify:          createAIJob(10, "unsupported-pipeline", modelId, ""),
+			notify:          createAIJob(11, "unsupported-pipeline", modelId, ""),
 			pipeline:        "unsupported-pipeline",
 			expectedErr:     "AI request validation failed for",
 			expectedOutputs: 0,
 		},
 		{
 			name:            "InvalidRequestData",
-			notify:          createAIJob(11, "text-to-image-invalid", modelId, ""),
+			notify:          createAIJob(12, "text-to-image-invalid", modelId, ""),
 			pipeline:        "text-to-image",
 			expectedErr:     "AI request validation failed for",
 			expectedOutputs: 0,
@@ -344,6 +362,13 @@ func TestRunAIJob(t *testing.T) {
 				var respFile bytes.Buffer
 				worker.ReadAudioB64DataUrl(expectedResp.Audio.Url, &respFile)
 				assert.Equal(len(results.Files[audResp.Audio.Url]), respFile.Len())
+			case "object-detection":
+				vidResp, ok := results.Results.(worker.ObjectDetectionResponse)
+				assert.True(ok)
+				assert.Equal("10", headers.Get("TaskId"))
+				assert.Equal(len(results.Files), 1)
+				expectedResp, _ := wkr.ObjectDetection(context.Background(), worker.GenObjectDetectionMultipartRequestBody{})
+				assert.Equal(expectedResp.Frames[0][0].Seed, vidResp.Frames[0][0].Seed)
 			}
 		}
 	})
@@ -380,6 +405,9 @@ func createAIJob(taskId int64, pipeline, modelId, inputUrl string) *net.NotifyAI
 		desc := "a young adult"
 		text := "let me tell you a story"
 		req = worker.GenTextToSpeechJSONRequestBody{Description: &desc, ModelId: &modelId, Text: &text}
+	case "object-detection":
+		inputFile.InitFromBytes(nil, inputUrl)
+		req = worker.GenObjectDetectionMultipartRequestBody{ModelId: &modelId, Video: inputFile}
 	case "unsupported-pipeline":
 		req = worker.GenTextToImageJSONRequestBody{Prompt: "test prompt", ModelId: &modelId}
 	case "text-to-image-invalid":
@@ -635,6 +663,29 @@ func (a *stubAIWorker) LiveVideoToVideo(ctx context.Context, req worker.GenLiveV
 	}
 }
 
+func (a *stubAIWorker) ObjectDetection(ctx context.Context, req worker.GenObjectDetectionMultipartRequestBody) (*worker.ObjectDetectionResponse, error) {
+	a.Called++
+	if a.Err != nil {
+		return nil, a.Err
+	} else {
+		return &worker.ObjectDetectionResponse{
+			Frames: [][]worker.Media{
+				{
+					{
+						Url:  "data:video/mp4;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABAQMAAAAl21bKAAAAA1BMVEUAAACnej3aAAAAAXRSTlMAQObYZgAAAApJREFUCNdjYAAAAAIAAeIhvDMAAAAASUVORK5CYII=",
+						Nsfw: false,
+						Seed: 113,
+					},
+				},
+			},
+			ConfidenceScores: "[[0.952, 0.948, ...], [0.961, 0.952, ...], [0.965, 0.96, ...], ...]",
+			Labels:           "[['person', 'person', ...], ['person', 'person', ...], ['person', 'person', ...], ...]",
"[[[0.14, 0.38, 640.13, 476.21], [343.38, 24.28, 640.14, 371.5], ...],[[75.1, 80.5, 320.7, 420.6], [60.4, 190.3, 370.2, 460.1], ...],[[50.5, 60.2, 200.8, 300.9], [100.3, 120.1, 350.5, 400.4], ...], ...]", + FramesPts: "[[0.03336666666666667], [0.06673333333333334], [0.1001], ...]" + }, nil + } +} + func (a *stubAIWorker) Warm(ctx context.Context, arg1, arg2 string, endpoint worker.RunnerEndpoint, flags worker.OptimizationFlags) error { a.Called++ return nil diff --git a/server/rpc.go b/server/rpc.go index 7223c56a9..043a9bbc8 100644 --- a/server/rpc.go +++ b/server/rpc.go @@ -79,6 +79,7 @@ type Orchestrator interface { ImageToText(ctx context.Context, requestID string, req worker.GenImageToTextMultipartRequestBody) (interface{}, error) TextToSpeech(ctx context.Context, requestID string, req worker.GenTextToSpeechJSONRequestBody) (interface{}, error) LiveVideoToVideo(ctx context.Context, requestID string, req worker.GenLiveVideoToVideoJSONRequestBody) (interface{}, error) + ObjectDetection(ctx context.Context, requestID string, req worker.GenObjectDetectionMultipartRequestBody) (interface{}, error) } // Balance describes methods for a session's balance maintenance diff --git a/server/rpc_test.go b/server/rpc_test.go index 43ec1a304..956fe4c05 100644 --- a/server/rpc_test.go +++ b/server/rpc_test.go @@ -227,6 +227,9 @@ func (r *stubOrchestrator) LiveVideoToVideo(ctx context.Context, requestID strin return nil, nil } +func (r *stubOrchestrator) ObjectDetection(ctx context.Context, requestID string, req worker.GenObjectDetectionMultipartRequestBody) (interface{}, error) { + return nil, nil +} func (r *stubOrchestrator) CheckAICapacity(pipeline, modelID string) bool { return true } @@ -1432,6 +1435,9 @@ func (r *mockOrchestrator) TextToSpeech(ctx context.Context, requestID string, r func (r *mockOrchestrator) LiveVideoToVideo(ctx context.Context, requestID string, req worker.GenLiveVideoToVideoJSONRequestBody) (interface{}, error) { return nil, nil } +func (r *mockOrchestrator) ObjectDetection(ctx context.Context, requestID string, req worker.GenObjectDetectionMultipartRequestBody) (interface{}, error) { + return nil, nil +} func (r *mockOrchestrator) CheckAICapacity(pipeline, modelID string) bool { return true }