Skip to content

Commit

Permalink
Merge branch 'ai-video' into parler_tts
Browse files Browse the repository at this point in the history
  • Loading branch information
pschroedl authored Oct 31, 2024
2 parents 782c77f + 7bb0026 commit baa6ed0
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 1 deletion.
4 changes: 3 additions & 1 deletion core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ const (
Capability_SegmentAnything2 Capability = 32
Capability_LLM Capability = 33
Capability_ImageToText Capability = 34
Capability_TextToSpeech Capability = 35
Capability_LiveVideoToVideo Capability = 35
Capability_TextToSpeech Capability = 36
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -121,6 +122,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_SegmentAnything2: "Segment anything 2",
Capability_LLM: "Llm",
Capability_ImageToText: "Image to text",
Capability_LiveVideoToVideo: "Live video to video",
Capability_TextToSpeech: "Text to speech",
}

Expand Down
30 changes: 30 additions & 0 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/llm", oapiReqValidator(lp.LLM()))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))
lp.transRPC.Handle("/image-to-text", oapiReqValidator(lp.ImageToText()))
lp.transRPC.Handle("/live-video-to-video", oapiReqValidator(lp.StartLiveVideoToVideo()))
lp.transRPC.Handle("/text-to-speech", oapiReqValidator(lp.TextToSpeech()))
// Additionally, there is the '/aiResults' endpoint registered in server/rpc.go
return nil
Expand Down Expand Up @@ -236,6 +237,35 @@ func (h *lphttp) ImageToText() http.Handler {
}

handleAIRequest(ctx, w, r, orch, req)

})
}

// StartLiveVideoToVideo returns the HTTP handler for starting a live
// video-to-video session. Every request gets a freshly generated random
// manifest ID, and the response is a JSON object with the publish and
// subscribe URL paths derived from that ID.
func (h *lphttp) StartLiveVideoToVideo() http.Handler {
	// liveVideoURLs is the response payload. Its fields carry no JSON tags, so
	// they are serialized under their Go field names.
	type liveVideoURLs struct {
		PublishUrl   string
		SubscribeUrl string
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		// skipping handleAIRequest for now until we have payments

		manifestID := string(core.RandomManifestID())
		publishPath := "/ai/live-video/" + manifestID

		body, err := json.Marshal(liveVideoURLs{
			PublishUrl:   publishPath,
			SubscribeUrl: publishPath + "/out",
		})
		if err != nil {
			respondWithError(w, err.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write(body)
	}
	return http.HandlerFunc(handler)
}

Expand Down
24 changes: 24 additions & 0 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/image-to-text", oapiReqValidator(handle(ls, multipartDecoder[worker.GenImageToTextMultipartRequestBody], processImageToText)))
ls.HTTPMux.Handle("/text-to-speech", oapiReqValidator(handle(ls, jsonDecoder[worker.GenTextToSpeechJSONRequestBody], processTextToSpeech)))

// This is called by the media server when the stream is ready
ls.HTTPMux.Handle("/live/video-to-video/start", ls.StartLiveVideo())

return nil
}

Expand Down Expand Up @@ -362,3 +365,24 @@ func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
_ = json.NewEncoder(w).Encode(resp)
})
}

// StartLiveVideo returns the HTTP handler for live video start notifications.
//
// The handler requires a non-empty "stream" form value and answers 400 Bad
// Request without it. Otherwise it creates a random request ID, builds the AI
// request parameters, calls processAIRequest and logs the outcome. Nothing is
// written to the response on that path, whether or not the call failed.
//
// NOTE(review): the live video-to-video case in processAIRequest appears to be
// commented out, so the empty request passed here presumably ends in the
// "unsupported request type" error for now — confirm.
func (ls *LivepeerServer) StartLiveVideo() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		stream := r.FormValue("stream")
		if len(stream) == 0 {
			http.Error(w, "Missing stream name", http.StatusBadRequest)
			return
		}

		reqID := string(core.RandomManifestID())

		var params aiRequestParams
		params.node = ls.LivepeerNode
		params.os = drivers.NodeStorage.NewSession(reqID)
		params.sessManager = ls.AISessionManager

		// Tag the context with the request ID before handing it on.
		ctx := clog.AddVal(r.Context(), "request_id", reqID)

		// TODO set model and initial parameters here if necessary (eg, prompt)
		resp, err := processAIRequest(ctx, params, struct{}{})
		clog.Infof(ctx, "Received live video AI request stream=%s resp=%v err=%v", stream, resp, err)
	})
}
26 changes: 26 additions & 0 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const defaultAudioToTextModelID = "openai/whisper-large-v3"
// Default model identifiers, one constant per AI pipeline.
// NOTE(review): the only use visible in this view is
// defaultLiveVideoToVideoModelID, assigned as the fallback model ID in a
// commented-out case of processAIRequest when the request carries no ModelId.
// The other constants presumably serve the same purpose — confirm at their
// use sites.

const defaultLLMModelID = "meta-llama/llama-3.1-8B-Instruct"
const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large"
const defaultImageToTextModelID = "Salesforce/blip-image-captioning-large"
const defaultLiveVideoToVideoModelID = "cumulo-autumn/stream-diffusion"
const defaultTextToSpeechModelID = "parler-tts/parler-tts-large-v1"

var errWrongFormat = fmt.Errorf("result not in correct format")
Expand Down Expand Up @@ -985,6 +986,19 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess
return &res, nil
}

// submitLiveVideoToVideo is the submit step for the live video-to-video
// pipeline. It is currently a stub: creation of the worker client is still
// commented out, so err is never set and the function always returns
// (nil, nil).
//
// The error branch is kept as scaffolding for when client creation is
// enabled. req.ModelId is a pointer and may be nil, so it is checked before
// being dereferenced for the metrics call instead of risking a nil-pointer
// panic on the error path.
func submitLiveVideoToVideo(ctx context.Context, params aiRequestParams, sess *AISession, req struct{ ModelId *string }) (any, error) {
	//client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
	var err error
	if err != nil {
		if monitor.Enabled {
			// Report an empty model name when the request did not specify one.
			var modelID string
			if req.ModelId != nil {
				modelID = *req.ModelId
			}
			monitor.AIRequestError(err.Error(), "LiveVideoToVideo", modelID, sess.OrchestratorInfo)
		}
		return nil, err
	}
	// TODO check urls and add sess.Transcoder to the host if necessary
	return nil, nil
}

func CalculateLLMLatencyScore(took time.Duration, tokensUsed int) float64 {
if tokensUsed <= 0 {
return 0
Expand Down Expand Up @@ -1333,6 +1347,18 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitTextToSpeech(ctx, params, sess, v)
}
/*
case worker.StartLiveVideoToVideoFormdataRequestBody:
cap = core.Capability_LiveVideoToVideo
modelID = defaultLiveVideoToVideoModelID
if v.ModelId != nil {
modelID = *v.ModelId
}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitLiveVideoToVideo(ctx, params, sess, v)
}
*/

default:
return nil, fmt.Errorf("unsupported request type %T", req)
}
Expand Down

0 comments on commit baa6ed0

Please sign in to comment.