From ae3df6f06d72785ad20698372143673ae32e5bcd Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Thu, 13 Jun 2024 14:15:56 +0200 Subject: [PATCH] [server,core]: fix RegisterAIWorker interface implementaiton --- core/ai.go | 8 ++++++++ server/ot_rpc.go | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/core/ai.go b/core/ai.go index edd4520c10..192f7ca7d5 100644 --- a/core/ai.go +++ b/core/ai.go @@ -27,6 +27,7 @@ type AI interface { } type RemoteAIResultChan chan *RemoteAIWorkerResult + type RemoteAIWorkerManager struct { // TODO Mapping by pipeline remoteWorkers []*RemoteAIWorker @@ -86,6 +87,9 @@ func (m *RemoteAIWorkerManager) TextToImage(ctx context.Context, req worker.Text taskID, taskChan := m.addTaskChan() defer m.removeTaskChan(taskID) + // select a remote worker + w := m.remoteWorkers[0] + // send request to remote worker jsonData, err := json.Marshal(req) if err != nil { @@ -99,6 +103,10 @@ func (m *RemoteAIWorkerManager) TextToImage(ctx context.Context, req worker.Text } m.handleAIRequest(remoteReq) // task id, pipeline + if err := w.stream.Send(remoteReq); err != nil { + return nil, err + } + select { case <-ctx.Done(): // return EOF signal diff --git a/server/ot_rpc.go b/server/ot_rpc.go index 4af17b5f19..2da3a631d5 100644 --- a/server/ot_rpc.go +++ b/server/ot_rpc.go @@ -368,7 +368,7 @@ func (h *lphttp) RegisterTranscoder(req *net.RegisterRequest, stream net.Transco return nil } -func (h *lphttp) RegisterRemoteAIWorker(req *net.RegisterRequest, stream net.Transcoder_RegisterAIWorkerServer) error { +func (h *lphttp) RegisterAIWorker(req *net.RegisterRequest, stream net.Transcoder_RegisterAIWorkerServer) error { from := common.GetConnectionAddr(stream.Context()) glog.Infof("Got a RegisterAIWorker request from transcoder=%s capacity=%d", from, req.Capacity)