Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ai: add models webhook url param #3209

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/livepeer/livepeer.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ func parseLivepeerConfig() starter.LivepeerConfig {
cfg.AIModels = flag.String("aiModels", *cfg.AIModels, "Set models (pipeline:model_id) for AI worker to load upon initialization")
cfg.AIModelsDir = flag.String("aiModelsDir", *cfg.AIModelsDir, "Set directory where AI model weights are stored")
cfg.AIRunnerImage = flag.String("aiRunnerImage", *cfg.AIRunnerImage, "Set the docker image for the AI runner: Example - livepeer/ai-runner:0.0.1")
cfg.AIModelsWebhookUrl = flag.String("aiModelsWebhookUrl", *cfg.AIModelsWebhookUrl, "URL for the AI models webhook or models config file path: Example - <protocol>://<host>/<path> or file://<path>")

// Onchain:
cfg.EthAcctAddr = flag.String("ethAcctAddr", *cfg.EthAcctAddr, "Existing Eth account address. For use when multiple ETH accounts exist in the keystore directory")
Expand Down
41 changes: 31 additions & 10 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ type LivepeerConfig struct {
OrchSecret *string
TranscodingOptions *string
AIModels *string
AIModelsWebhookUrl *string
MaxAttempts *int
SelectRandWeight *float64
SelectStakeWeight *float64
Expand Down Expand Up @@ -202,6 +203,7 @@ func DefaultLivepeerConfig() LivepeerConfig {
defaultAIModels := ""
defaultAIModelsDir := ""
defaultAIRunnerImage := "livepeer/ai-runner:latest"
defaultAIModelsWebhookUrl := ""

// Onchain:
defaultEthAcctAddr := ""
Expand Down Expand Up @@ -297,10 +299,11 @@ func DefaultLivepeerConfig() LivepeerConfig {
TestTranscoder: &defaultTestTranscoder,

// AI:
AIWorker: &defaultAIWorker,
AIModels: &defaultAIModels,
AIModelsDir: &defaultAIModelsDir,
AIRunnerImage: &defaultAIRunnerImage,
AIWorker: &defaultAIWorker,
AIModels: &defaultAIModels,
AIModelsDir: &defaultAIModelsDir,
AIRunnerImage: &defaultAIRunnerImage,
AIModelsWebhookUrl: &defaultAIModelsWebhookUrl,

// Onchain:
EthAcctAddr: &defaultEthAcctAddr,
Expand Down Expand Up @@ -1191,11 +1194,29 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
currencyBase = currency
}

if *cfg.AIModels != "" {
configs, err := core.ParseAIModelConfigs(*cfg.AIModels)
if err != nil {
glog.Errorf("Error parsing -aiModels: %v", err)
return
if *cfg.AIModels != "" && *cfg.AIModelsWebhookUrl != "" {
pwilczynskiclearcode marked this conversation as resolved.
Show resolved Hide resolved
glog.Error("Both '-aiModels' and '-aiModelsWebhookUrl' flags are set. Please specify only one of them.")
return
} else if *cfg.AIModelsWebhookUrl != "" || *cfg.AIModels != "" {
var configs []core.AIModelConfig
var err error

if *cfg.AIModelsWebhookUrl != "" {
webhook, err := core.NewAIModelWebhook(*cfg.AIModelsWebhookUrl, 1*time.Minute)
if err != nil {
glog.Errorf("Error creating AI model webhook: %v", err)
return
}
configs = webhook.GetConfigs()
n.ModelsWebhook = webhook
}

if *cfg.AIModels != "" {
configs, err = core.ParseAIModelConfigs(*cfg.AIModels)
if err != nil {
glog.Errorf("Error parsing -aiModels: %v", err)
return
}
}

for _, config := range configs {
Copy link
Contributor

@pwilczynskiclearcode pwilczynskiclearcode Oct 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gioelecerati This entire loop needs to be run each time there are some changes to the ModelsWebhook.configs

Expand Down Expand Up @@ -1361,7 +1382,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
}
}
} else {
glog.Error("The '-aiModels' flag was set, but no model configuration was provided. Please specify the model configuration using the '-aiModels' flag.")
glog.Error("The '-aiModels' or '-aiModelsWebhookUrl' flag was set, but no model configuration was provided. Please specify the model configuration using either the '-aiModels' or the '-aiModelsWebhookUrl' flag.")
return
}

Expand Down
101 changes: 101 additions & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@ package core

import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
"net/http"
"os"
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/golang/glog"
"github.com/livepeer/ai-worker/worker"
)

Expand Down Expand Up @@ -127,3 +133,98 @@ func ParseStepsFromModelID(modelID *string, defaultSteps float64) float64 {

return numInferenceSteps
}

type ModelsWebhook struct {
configs []AIModelConfig
source string
mu sync.RWMutex
lastHash string
refreshInt time.Duration
stopChan chan struct{}
}

func NewAIModelWebhook(source string, refreshInterval time.Duration) (*ModelsWebhook, error) {
webhook := &ModelsWebhook{
source: source,
refreshInt: refreshInterval,
stopChan: make(chan struct{}),
}
err := webhook.refreshConfigs()
if err != nil {
return nil, err
}
go webhook.startRefreshing()
return webhook, nil
}

func (w *ModelsWebhook) startRefreshing() {
ticker := time.NewTicker(w.refreshInt)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err := w.refreshConfigs()
if err != nil {
glog.Errorf("Error refreshing AI model configs: %v", err)
}
case <-w.stopChan:
return
}
}
}

func (w *ModelsWebhook) refreshConfigs() error {
content, err := w.fetchContent()
if err != nil {
return err
}
hash := hashContent(content)
if hash != w.lastHash {
configs, err := ParseAIModelConfigs(content)
if err != nil {
return err
}
w.mu.Lock()
w.configs = configs
w.lastHash = hash
w.mu.Unlock()
glog.V(6).Info("AI Model configurations have been updated.")
}
return nil
}

func (w *ModelsWebhook) fetchContent() (string, error) {
if strings.HasPrefix(w.source, "file://") {
filePath := strings.TrimPrefix(w.source, "file://")
data, err := os.ReadFile(filePath)
if err != nil {
return "", err
}
return string(data), nil
} else {
resp, err := http.Get(w.source)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(body), nil
}
}

func hashContent(content string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(content)))
}

func (w *ModelsWebhook) GetConfigs() []AIModelConfig {
w.mu.RLock()
defer w.mu.RUnlock()
return w.configs
}

func (w *ModelsWebhook) Stop() {
mjh1 marked this conversation as resolved.
Show resolved Hide resolved
close(w.stopChan)
}
109 changes: 109 additions & 0 deletions core/ai_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package core

import (
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand All @@ -18,3 +22,108 @@ func TestPipelineToCapability(t *testing.T) {
assert.Error(t, err)
assert.Equal(t, cap, Capability_Unused)
}

func TestModelsWebhook(t *testing.T) {
assert := assert.New(t)

mockResponses := []string{
`[{"name":"Model1", "model_id":"model1", "pipeline":"text-to-image", "warm":true}]`,
`[{"name":"Model2", "model_id":"model2", "pipeline":"image-to-image", "warm":false}]`,
}
responseIndex := 0
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockResponses[responseIndex]))
responseIndex = (responseIndex + 1) % len(mockResponses)
}))
defer mockServer.Close()

refreshInterval := 100 * time.Millisecond

webhook, err := NewAIModelWebhook(mockServer.URL, refreshInterval)
assert.NoError(err)
assert.NotNil(webhook)
defer webhook.Stop()

configs := webhook.GetConfigs()
assert.Len(configs, 1)
assert.Equal("model1", configs[0].ModelID)
assert.Equal("text-to-image", configs[0].Pipeline)
assert.True(configs[0].Warm)

time.Sleep(refreshInterval * 2)

// Check if models are updated
configs = webhook.GetConfigs()
assert.Len(configs, 1)
assert.Equal("model2", configs[0].ModelID)
assert.Equal("image-to-image", configs[0].Pipeline)
assert.False(configs[0].Warm)

_, err = NewAIModelWebhook("http://invalid-url", refreshInterval)
assert.Error(err)

invalidServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`invalid json`))
}))
defer invalidServer.Close()

invalidWebhook, err := NewAIModelWebhook(invalidServer.URL, refreshInterval)
assert.Error(err)
assert.Nil(invalidWebhook)
}

func TestModelsFile(t *testing.T) {
assert := assert.New(t)

tmpfile, err := os.CreateTemp("", "aimodels*.json")
assert.NoError(err)
defer os.Remove(tmpfile.Name())

initialContent := `[{"name":"Model1", "model_id":"model1", "pipeline":"text-to-image", "warm":true}]`
_, err = tmpfile.Write([]byte(initialContent))
assert.NoError(err)
tmpfile.Close()

refreshInterval := 100 * time.Millisecond

webhook, err := NewAIModelWebhook("file://"+tmpfile.Name(), refreshInterval)
assert.NoError(err)
assert.NotNil(webhook)
defer webhook.Stop()

configs := webhook.GetConfigs()
assert.Len(configs, 1)
assert.Equal("model1", configs[0].ModelID)
assert.Equal("text-to-image", configs[0].Pipeline)
assert.True(configs[0].Warm)

time.Sleep(refreshInterval / 2)
updatedContent := `[{"name":"Model2", "model_id":"model2", "pipeline":"image-to-image", "warm":false}]`
err = os.WriteFile(tmpfile.Name(), []byte(updatedContent), 0644)
assert.NoError(err)

time.Sleep(refreshInterval * 2)

// Check if models are updated
configs = webhook.GetConfigs()
assert.Len(configs, 1)
assert.Equal("model2", configs[0].ModelID)
assert.Equal("image-to-image", configs[0].Pipeline)
assert.False(configs[0].Warm)

_, err = NewAIModelWebhook("file:///nonexistent/path", refreshInterval)
assert.Error(err)

invalidFile, err := os.CreateTemp("", "invalid*.json")
assert.NoError(err)
defer os.Remove(invalidFile.Name())
_, err = invalidFile.Write([]byte(`invalid json`))
assert.NoError(err)
invalidFile.Close()

invalidWebhook, err := NewAIModelWebhook("file://"+invalidFile.Name(), refreshInterval)
assert.Error(err)
assert.Nil(invalidWebhook)
}
3 changes: 2 additions & 1 deletion core/livepeernode.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ type LivepeerNode struct {
Database *common.DB

// AI worker public fields
AIWorker AI
AIWorker AI
ModelsWebhook *ModelsWebhook
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we don't need this field for now as it's not used right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the point of this to fetch updated model configs? I added it to be used where needed doing updatedModels = ModelsWebhook.GetConfigs() - but maybe @pwilczynskiclearcode knows better where this is needed to be used

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's assigned in here... and must live for the entire orchestrator lifetime and not be garbage collected or stopped.
I think it's needed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha yep, it'll be integrated in a further PR


// Transcoder public fields
SegmentChans map[ManifestID]SegmentChan
Expand Down
Loading