diff --git a/core/ai.go b/core/ai.go index dd9eac1e82..06e8e5c4fe 100644 --- a/core/ai.go +++ b/core/ai.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" "sync" + "time" "github.com/golang/glog" "github.com/livepeer/ai-worker/worker" @@ -16,6 +17,9 @@ import ( "github.com/livepeer/go-livepeer/net" ) +// TODO: seperate timeout for warm requests +const workerTimeout = 60 * time.Second // Adjust this value as needed + type AI interface { TextToImage(context.Context, worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) ImageToImage(context.Context, worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) @@ -88,7 +92,7 @@ func (m *RemoteAIWorkerManager) Manage(stream net.Transcoder_RegisterAIWorkerSer m.workersMutex.Unlock() <-worker.eof - glog.Infof("Remote AI worker=%s done, removing from live AI workers map", from) + glog.Infof("Remote AI worker stream closed, removing from live AI workers map worker=%s", from) m.workersMutex.Lock() delete(m.liveWorkers, stream) @@ -123,11 +127,6 @@ func (m *RemoteAIWorkerManager) processAIRequest(ctx context.Context, capability modelID := getModelID(req) - w, err := m.selectWorker(capability, modelID) - if err != nil { - return nil, err - } - jsonData, err := json.Marshal(req) if err != nil { return nil, err @@ -139,35 +138,21 @@ func (m *RemoteAIWorkerManager) processAIRequest(ctx context.Context, capability Data: jsonData, } - if err := w.stream.Send(remoteReq); err != nil { - return nil, err - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case chanData := <-taskChan: - if chanData.Err != "" { - return nil, fmt.Errorf("%v", chanData.Err) + workerCount := m.getWorkerCount(capability, modelID) + for i := 0; i < workerCount; i++ { + w, err := m.selectWorker(capability, modelID) + if err != nil { + return nil, err } - glog.Infof("Received AI result for task %d", chanData.TaskID) - var res interface{} - switch aiRequestType { - case net.AIRequestType_ImageToVideo: - var videoRes worker.VideoResponse - if err := json.Unmarshal(chanData.Bytes, &videoRes); err != nil { - return nil, err - } - res = &videoRes - default: - var imgRes worker.ImageResponse - if err := json.Unmarshal(chanData.Bytes, &imgRes); err != nil { - return nil, err - } - res = &imgRes + + chanData, err := m.sendRequestToWorker(ctx, w, remoteReq, taskChan) + if err == nil { + return m.processWorkerResponse(chanData, aiRequestType) } - return res, nil + + glog.Warningf("Worker %s failed, retrying taskID=%v err=%v", w.addr, taskID, err) } + return nil, ErrNoTranscodersAvailable } func (m *RemoteAIWorkerManager) selectWorker(capability Capability, modelID string) (*RemoteAIWorker, error) { @@ -189,6 +174,51 @@ func (m *RemoteAIWorkerManager) selectWorker(capability Capability, modelID stri return w, nil } +func (m *RemoteAIWorkerManager) getWorkerCount(capability Capability, modelID string) int { + m.workersMutex.Lock() + defer m.workersMutex.Unlock() + return len(m.remoteWorkers[capability][modelID]) +} + +func (m *RemoteAIWorkerManager) sendRequestToWorker(ctx context.Context, w *RemoteAIWorker, remoteReq *net.NotifyAIJob, taskChan RemoteAIResultChan) (*RemoteAIWorkerResult, error) { + if err := w.stream.Send(remoteReq); err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + timeoutCtx, cancel := context.WithTimeout(ctx, workerTimeout) + defer cancel() + + select { + case <-timeoutCtx.Done(): + return nil, fmt.Errorf("worker timed out") + case chanData := <-taskChan: + if chanData.Err != "" { + return nil, fmt.Errorf("worker returned error: %s", chanData.Err) + } + return chanData, nil + } +} + +func (m *RemoteAIWorkerManager) processWorkerResponse(chanData *RemoteAIWorkerResult, aiRequestType net.AIRequestType) (interface{}, error) { + glog.Infof("Received AI result for task %d", chanData.TaskID) + var res interface{} + switch aiRequestType { + case net.AIRequestType_ImageToVideo: + var videoRes worker.VideoResponse + if err := json.Unmarshal(chanData.Bytes, &videoRes); err != nil { + return nil, fmt.Errorf("failed to unmarshal video response: %w", err) + } + res = &videoRes + default: + var imgRes worker.ImageResponse + if err := json.Unmarshal(chanData.Bytes, &imgRes); err != nil { + return nil, fmt.Errorf("failed to unmarshal image response: %w", err) + } + res = &imgRes + } + return res, nil +} + func (m *RemoteAIWorkerManager) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) { res, err := m.processAIRequest(ctx, Capability_TextToImage, req, net.AIRequestType_TextToImage) if err != nil {