Skip to content

Commit

Permalink
feat: rotate workers until timeout, success, or all fail
Browse files Browse the repository at this point in the history
  • Loading branch information
kyriediculous committed Jul 27, 2024
1 parent 605e821 commit af8925b
Showing 1 changed file with 62 additions and 32 deletions.
94 changes: 62 additions & 32 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ import (
"strconv"
"strings"
"sync"
"time"

"github.com/golang/glog"
"github.com/livepeer/ai-worker/worker"
"github.com/livepeer/go-livepeer/common"
"github.com/livepeer/go-livepeer/net"
)

// TODO: seperate timeout for warm requests
const workerTimeout = 60 * time.Second // Adjust this value as needed

type AI interface {
TextToImage(context.Context, worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
Expand Down Expand Up @@ -88,7 +92,7 @@ func (m *RemoteAIWorkerManager) Manage(stream net.Transcoder_RegisterAIWorkerSer
m.workersMutex.Unlock()

<-worker.eof
glog.Infof("Remote AI worker=%s done, removing from live AI workers map", from)
glog.Infof("Remote AI worker stream closed, removing from live AI workers map worker=%s", from)

m.workersMutex.Lock()
delete(m.liveWorkers, stream)
Expand Down Expand Up @@ -123,11 +127,6 @@ func (m *RemoteAIWorkerManager) processAIRequest(ctx context.Context, capability

modelID := getModelID(req)

w, err := m.selectWorker(capability, modelID)
if err != nil {
return nil, err
}

jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
Expand All @@ -139,35 +138,21 @@ func (m *RemoteAIWorkerManager) processAIRequest(ctx context.Context, capability
Data: jsonData,
}

if err := w.stream.Send(remoteReq); err != nil {
return nil, err
}

select {
case <-ctx.Done():
return nil, ctx.Err()
case chanData := <-taskChan:
if chanData.Err != "" {
return nil, fmt.Errorf("%v", chanData.Err)
workerCount := m.getWorkerCount(capability, modelID)
for i := 0; i < workerCount; i++ {
w, err := m.selectWorker(capability, modelID)
if err != nil {
return nil, err
}
glog.Infof("Received AI result for task %d", chanData.TaskID)
var res interface{}
switch aiRequestType {
case net.AIRequestType_ImageToVideo:
var videoRes worker.VideoResponse
if err := json.Unmarshal(chanData.Bytes, &videoRes); err != nil {
return nil, err
}
res = &videoRes
default:
var imgRes worker.ImageResponse
if err := json.Unmarshal(chanData.Bytes, &imgRes); err != nil {
return nil, err
}
res = &imgRes

chanData, err := m.sendRequestToWorker(ctx, w, remoteReq, taskChan)
if err == nil {
return m.processWorkerResponse(chanData, aiRequestType)
}
return res, nil

glog.Warningf("Worker %s failed, retrying taskID=%v err=%v", w.addr, taskID, err)
}
return nil, ErrNoTranscodersAvailable
}

func (m *RemoteAIWorkerManager) selectWorker(capability Capability, modelID string) (*RemoteAIWorker, error) {
Expand All @@ -189,6 +174,51 @@ func (m *RemoteAIWorkerManager) selectWorker(capability Capability, modelID stri
return w, nil
}

func (m *RemoteAIWorkerManager) getWorkerCount(capability Capability, modelID string) int {
m.workersMutex.Lock()
defer m.workersMutex.Unlock()
return len(m.remoteWorkers[capability][modelID])
}

func (m *RemoteAIWorkerManager) sendRequestToWorker(ctx context.Context, w *RemoteAIWorker, remoteReq *net.NotifyAIJob, taskChan RemoteAIResultChan) (*RemoteAIWorkerResult, error) {
if err := w.stream.Send(remoteReq); err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}

timeoutCtx, cancel := context.WithTimeout(ctx, workerTimeout)
defer cancel()

select {
case <-timeoutCtx.Done():
return nil, fmt.Errorf("worker timed out")
case chanData := <-taskChan:
if chanData.Err != "" {
return nil, fmt.Errorf("worker returned error: %s", chanData.Err)
}
return chanData, nil
}
}

func (m *RemoteAIWorkerManager) processWorkerResponse(chanData *RemoteAIWorkerResult, aiRequestType net.AIRequestType) (interface{}, error) {
glog.Infof("Received AI result for task %d", chanData.TaskID)
var res interface{}
switch aiRequestType {
case net.AIRequestType_ImageToVideo:
var videoRes worker.VideoResponse
if err := json.Unmarshal(chanData.Bytes, &videoRes); err != nil {
return nil, fmt.Errorf("failed to unmarshal video response: %w", err)
}
res = &videoRes
default:
var imgRes worker.ImageResponse
if err := json.Unmarshal(chanData.Bytes, &imgRes); err != nil {
return nil, fmt.Errorf("failed to unmarshal image response: %w", err)
}
res = &imgRes
}
return res, nil
}

func (m *RemoteAIWorkerManager) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
res, err := m.processAIRequest(ctx, Capability_TextToImage, req, net.AIRequestType_TextToImage)
if err != nil {
Expand Down

0 comments on commit af8925b

Please sign in to comment.