Skip to content

Commit

Permalink
ai: map remote workers by pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenr0d committed Jul 6, 2024
1 parent 96f7091 commit b618c85
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ type AI interface {
type RemoteAIResultChan chan *RemoteAIWorkerResult

type RemoteAIWorkerManager struct {
// TODO: account for pipeline
remoteWorkers map[string][]*RemoteAIWorker
// workers mapped by Pipeline(Capability) + ModelID
remoteWorkers map[Capability]map[string][]*RemoteAIWorker
liveWorkers map[net.Transcoder_RegisterAIWorkerServer]*RemoteAIWorker
workersMutex sync.Mutex

Expand All @@ -48,7 +48,7 @@ type RemoteAIWorkerResult struct {

func NewRemoteAIWorkerManager() *RemoteAIWorkerManager {
return &RemoteAIWorkerManager{
remoteWorkers: map[string][]*RemoteAIWorker{},
remoteWorkers: map[Capability]map[string][]*RemoteAIWorker{},
liveWorkers: map[net.Transcoder_RegisterAIWorkerServer]*RemoteAIWorker{},
workersMutex: sync.Mutex{},

Expand All @@ -70,9 +70,9 @@ func (m *RemoteAIWorkerManager) Manage(stream net.Transcoder_RegisterAIWorkerSer
}()

m.workersMutex.Lock()
for _, constraints := range capabilities.Constraints {
for cap, constraints := range capabilities.Constraints {
for modelID, _ := range constraints.Models {
m.remoteWorkers[modelID] = append(m.remoteWorkers[modelID], worker)
m.remoteWorkers[Capability(cap)][modelID] = append(m.remoteWorkers[Capability(cap)][modelID], worker)
}
}
m.liveWorkers[stream] = worker
Expand All @@ -95,16 +95,16 @@ func (m *RemoteAIWorkerManager) TextToImage(ctx context.Context, req worker.Text
taskID, taskChan := m.addTaskChan()
defer m.removeTaskChan(taskID)

var workerCount = len(m.remoteWorkers[*req.ModelId])
var workerCount = len(m.remoteWorkers[Capability_TextToImage][*req.ModelId])
if workerCount == 0 {
return nil, ErrOrchCap
}

// select a remote worker
w := m.remoteWorkers[*req.ModelId][0]
w := m.remoteWorkers[Capability_TextToImage][*req.ModelId][0]
glog.Infof("Selected worker %s for model %s; Total worker count: %v", w.addr, *req.ModelId, workerCount)
if workerCount > 1 {
m.remoteWorkers[*req.ModelId] = append(m.remoteWorkers[*req.ModelId][1:], m.remoteWorkers[*req.ModelId][0])
m.remoteWorkers[Capability_TextToImage][*req.ModelId] = append(m.remoteWorkers[Capability_TextToImage][*req.ModelId][1:], m.remoteWorkers[Capability_TextToImage][*req.ModelId][0])
}

// send request to remote worker
Expand Down Expand Up @@ -144,16 +144,16 @@ func (m *RemoteAIWorkerManager) ImageToImage(ctx context.Context, req worker.Ima
taskID, taskChan := m.addTaskChan()
defer m.removeTaskChan(taskID)

var workerCount = len(m.remoteWorkers[*req.ModelId])
var workerCount = len(m.remoteWorkers[Capability_ImageToImage][*req.ModelId])
if workerCount == 0 {
return nil, ErrOrchCap
}

// select a remote worker
w := m.remoteWorkers[*req.ModelId][0]
w := m.remoteWorkers[Capability_ImageToImage][*req.ModelId][0]
glog.Infof("Selected worker %s for model %s; Total worker count: %v", w.addr, *req.ModelId, workerCount)
if workerCount > 1 {
m.remoteWorkers[*req.ModelId] = append(m.remoteWorkers[*req.ModelId][1:], m.remoteWorkers[*req.ModelId][0])
m.remoteWorkers[Capability_ImageToImage][*req.ModelId] = append(m.remoteWorkers[Capability_ImageToImage][*req.ModelId][1:], m.remoteWorkers[Capability_ImageToImage][*req.ModelId][0])
}

// send request to remote worker
Expand Down Expand Up @@ -198,16 +198,16 @@ func (m *RemoteAIWorkerManager) Upscale(ctx context.Context, req worker.UpscaleM
defer m.removeTaskChan(taskID)

// select a remote worker
var workerCount = len(m.remoteWorkers[*req.ModelId])
var workerCount = len(m.remoteWorkers[Capability_Upscale][*req.ModelId])
if workerCount == 0 {
return nil, ErrOrchCap
}

// select a remote worker
w := m.remoteWorkers[*req.ModelId][0]
w := m.remoteWorkers[Capability_Upscale][*req.ModelId][0]
glog.Infof("Selected worker %s for model %s; Total worker count: %v", w.addr, *req.ModelId, workerCount)
if workerCount > 1 {
m.remoteWorkers[*req.ModelId] = append(m.remoteWorkers[*req.ModelId][1:], m.remoteWorkers[*req.ModelId][0])
m.remoteWorkers[Capability_Upscale][*req.ModelId] = append(m.remoteWorkers[Capability_Upscale][*req.ModelId][1:], m.remoteWorkers[Capability_Upscale][*req.ModelId][0])
}

// send request to remote worker
Expand Down

0 comments on commit b618c85

Please sign in to comment.