Skip to content

Commit

Permalink
ai: map remote workers by pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenr0d committed Jul 6, 2024
1 parent 96f7091 commit 28fe246
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
28 changes: 14 additions & 14 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ type AI interface {
type RemoteAIResultChan chan *RemoteAIWorkerResult

type RemoteAIWorkerManager struct {
// TODO: account for pipeline
remoteWorkers map[string][]*RemoteAIWorker
// remoteWorkers maps pipeline (Capability) -> model ID -> the workers registered for that pipeline/model pair.
remoteWorkers map[Capability]map[string][]*RemoteAIWorker
liveWorkers map[net.Transcoder_RegisterAIWorkerServer]*RemoteAIWorker
workersMutex sync.Mutex

Expand All @@ -48,7 +48,7 @@ type RemoteAIWorkerResult struct {

func NewRemoteAIWorkerManager() *RemoteAIWorkerManager {
return &RemoteAIWorkerManager{
remoteWorkers: map[string][]*RemoteAIWorker{},
remoteWorkers: map[Capability]map[string][]*RemoteAIWorker{},
liveWorkers: map[net.Transcoder_RegisterAIWorkerServer]*RemoteAIWorker{},
workersMutex: sync.Mutex{},

Expand All @@ -70,9 +70,9 @@ func (m *RemoteAIWorkerManager) Manage(stream net.Transcoder_RegisterAIWorkerSer
}()

m.workersMutex.Lock()
for _, constraints := range capabilities.Constraints {
for cap, constraints := range capabilities.Constraints {
for modelID, _ := range constraints.Models {
m.remoteWorkers[modelID] = append(m.remoteWorkers[modelID], worker)
m.remoteWorkers[Capability(cap)][modelID] = append(m.remoteWorkers[Capability(cap)][modelID], worker)
}
}
m.liveWorkers[stream] = worker
Expand All @@ -95,16 +95,16 @@ func (m *RemoteAIWorkerManager) TextToImage(ctx context.Context, req worker.Text
taskID, taskChan := m.addTaskChan()
defer m.removeTaskChan(taskID)

var workerCount = len(m.remoteWorkers[*req.ModelId])
var workerCount = len(m.remoteWorkers[Capability_TextToImage][*req.ModelId])
if workerCount == 0 {
return nil, ErrOrchCap
}

// select a remote worker
w := m.remoteWorkers[*req.ModelId][0]
w := m.remoteWorkers[Capability_TextToImage][*req.ModelId][0]
glog.Infof("Selected worker %s for model %s; Total worker count: %v", w.addr, *req.ModelId, workerCount)
if workerCount > 1 {
m.remoteWorkers[*req.ModelId] = append(m.remoteWorkers[*req.ModelId][1:], m.remoteWorkers[*req.ModelId][0])
m.remoteWorkers[Capability_TextToImage][*req.ModelId] = append(m.remoteWorkers[Capability_TextToImage][*req.ModelId][1:], m.remoteWorkers[Capability_TextToImage][*req.ModelId][0])
}

// send request to remote worker
Expand Down Expand Up @@ -144,16 +144,16 @@ func (m *RemoteAIWorkerManager) ImageToImage(ctx context.Context, req worker.Ima
taskID, taskChan := m.addTaskChan()
defer m.removeTaskChan(taskID)

var workerCount = len(m.remoteWorkers[*req.ModelId])
var workerCount = len(m.remoteWorkers[Capability_ImageToImage][*req.ModelId])
if workerCount == 0 {
return nil, ErrOrchCap
}

// select a remote worker
w := m.remoteWorkers[*req.ModelId][0]
w := m.remoteWorkers[Capability_ImageToImage][*req.ModelId][0]
glog.Infof("Selected worker %s for model %s; Total worker count: %v", w.addr, *req.ModelId, workerCount)
if workerCount > 1 {
m.remoteWorkers[*req.ModelId] = append(m.remoteWorkers[*req.ModelId][1:], m.remoteWorkers[*req.ModelId][0])
m.remoteWorkers[Capability_ImageToImage][*req.ModelId] = append(m.remoteWorkers[Capability_ImageToImage][*req.ModelId][1:], m.remoteWorkers[Capability_ImageToImage][*req.ModelId][0])
}

// send request to remote worker
Expand Down Expand Up @@ -198,16 +198,16 @@ func (m *RemoteAIWorkerManager) Upscale(ctx context.Context, req worker.UpscaleM
defer m.removeTaskChan(taskID)

// select a remote worker
var workerCount = len(m.remoteWorkers[*req.ModelId])
var workerCount = len(m.remoteWorkers[Capability_Upscale][*req.ModelId])
if workerCount == 0 {
return nil, ErrOrchCap
}

// select a remote worker
w := m.remoteWorkers[*req.ModelId][0]
w := m.remoteWorkers[Capability_Upscale][*req.ModelId][0]
glog.Infof("Selected worker %s for model %s; Total worker count: %v", w.addr, *req.ModelId, workerCount)
if workerCount > 1 {
m.remoteWorkers[*req.ModelId] = append(m.remoteWorkers[*req.ModelId][1:], m.remoteWorkers[*req.ModelId][0])
m.remoteWorkers[Capability_Upscale][*req.ModelId] = append(m.remoteWorkers[Capability_Upscale][*req.ModelId][1:], m.remoteWorkers[Capability_Upscale][*req.ModelId][0])
}

// send request to remote worker
Expand Down
15 changes: 14 additions & 1 deletion core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,20 @@ func (orch *orchestrator) CheckCapacity(mid ManifestID) error {

// CheckAICapacity reports whether the orchestrator has at least one remote AI
// worker registered for the given pipeline and modelID.
//
// pipeline is the pipeline name: "text-to-image", "image-to-image" or
// "upscale". Any other pipeline name yields false. modelID is looked up within
// the workers registered for that pipeline only.
func (orch *orchestrator) CheckAICapacity(pipeline, modelID string) bool {
	// TODO: Accept a Capability instead? Since this is a public method, it may
	// be better to keep accepting the pipeline string directly.
	var capability Capability
	switch pipeline {
	case "text-to-image":
		capability = Capability_TextToImage
	case "image-to-image":
		capability = Capability_ImageToImage
	case "upscale":
		capability = Capability_Upscale
	default:
		// Unknown pipeline: no worker can serve it.
		return false
	}

	// Reading through a missing (nil) inner map is safe and yields an empty
	// slice, so an unregistered pipeline or model simply reports no capacity.
	// NOTE(review): remoteWorkers is read here without holding workersMutex,
	// while Manage mutates it under that lock — confirm whether this read
	// needs the same synchronization.
	return len(orch.node.AIManager.remoteWorkers[capability][modelID]) > 0
}

func (orch *orchestrator) TranscodeSeg(ctx context.Context, md *SegTranscodingMetadata, seg *stream.HLSSegment) (*TranscodeResult, error) {
Expand Down

0 comments on commit 28fe246

Please sign in to comment.