Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement apiToken failover mechanism #1256

Merged
merged 36 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
095b25e
feat: implement apiToken failover mechanism
cr7258 Aug 27, 2024
4af200c
Use SetSharedData for leader election and syncing apiTokens between W…
cr7258 Aug 31, 2024
192d855
Merge branch 'main' into failover
cr7258 Sep 1, 2024
856343c
support failover for all models
cr7258 Sep 1, 2024
7d5f427
add cas retry logic
cr7258 Sep 7, 2024
ee49848
wrap getApiTokenInUse funtion
cr7258 Sep 7, 2024
1e40d82
only removed the apiToken when the number of consecutive request fail…
cr7258 Sep 25, 2024
432395b
use uuid as vmid
cr7258 Sep 25, 2024
67551f2
fix byte covert
cr7258 Sep 26, 2024
82b2284
reset shared data during initialization
cr7258 Sep 26, 2024
daa48fe
Merge branch 'main' into failover
cr7258 Sep 26, 2024
8a818ed
failover support new model
cr7258 Sep 26, 2024
0554c85
fix
cr7258 Sep 26, 2024
e3401d5
move SetApiTokensFailover to complete function
cr7258 Sep 28, 2024
0f79913
wrap failover logic into ProviderConfig
cr7258 Sep 28, 2024
bda87f1
fix
cr7258 Sep 28, 2024
263c38c
config envoy local cluster and isolate apiToken ctx between different…
cr7258 Oct 5, 2024
374d5be
update README.md
cr7258 Oct 7, 2024
fd49f2d
add description
cr7258 Oct 7, 2024
66c371b
fix nil point exception when don't set failover config
cr7258 Oct 7, 2024
2130c00
Merge branch 'main' into failover
cr7258 Oct 7, 2024
7f36c09
support github provider
cr7258 Oct 7, 2024
01b92d8
fix
cr7258 Oct 10, 2024
a11a38b
Merge branch 'main' into failover
cr7258 Oct 10, 2024
01b0eec
unified the transformation of HTTP headers and body for ai-proxy and …
cr7258 Oct 17, 2024
a180e65
fix readme
cr7258 Oct 17, 2024
a72a8a1
optimize
cr7258 Oct 17, 2024
6a62333
refine transform headers and body
cr7258 Nov 3, 2024
f1f375e
move defaultInsertHttpContextMessage to context.go
cr7258 Nov 3, 2024
0296110
fix
cr7258 Nov 5, 2024
f164854
remove get context in original protocol
cr7258 Nov 5, 2024
0938f98
add reset apiToken log
cr7258 Nov 6, 2024
8e425c3
add GetApiName to determine apiName for original protocol
cr7258 Nov 14, 2024
51c766f
fix
cr7258 Nov 14, 2024
0b3422a
make GetApiName optional
cr7258 Nov 14, 2024
f0f24cc
Merge branch 'main' into failover
CH3CHO Nov 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions plugins/wasm-go/extensions/ai-proxy/config/config.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package config

import (
"github.com/tidwall/gjson"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
"github.com/tidwall/gjson"
)

// @Name ai-proxy
Expand Down
26 changes: 23 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,17 @@ func main() {
}

func parseConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error {
// log.Debugf("loading config: %s", json.String())

pluginConfig.FromJson(json)
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
return err
}

providerConfig := pluginConfig.GetProviderConfig()
providerConfig.SetApiTokensFailover(log)
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved

return nil
}

Expand All @@ -72,8 +74,18 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
ctx.SetContext(ctxKeyApiName, apiName)

if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()

log.Debugf("ApiTokens: %s, UnavailableApiTokens: %s", provider.ApiTokens, provider.UnavailableApiTokens)
apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader("ApiToken-Health-Check")
if apiTokenHealthCheck != "" {
ctx.SetContext(provider.ApiTokenInUse, apiTokenHealthCheck)
} else {
providerConfig := pluginConfig.GetProviderConfig()
ctx.SetContext(provider.ApiTokenInUse, providerConfig.GetRandomToken())
}

hasRequestBody := wrapper.HasRequestBody()
action, err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
Expand All @@ -85,6 +97,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
return action
}

_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
return types.ActionContinue
}
Expand Down Expand Up @@ -145,6 +158,13 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
log.Errorf("unable to load :status header from response: %v", err)
}
ctx.DontReadResponseBody()

if ctx.GetContext(provider.ApiTokenHealthCheck) == nil {
unavailableApiToken := ctx.GetContext(provider.ApiTokenInUse).(string)
providerConfig := pluginConfig.GetProviderConfig()
providerConfig.HandleUnavailableApiToken(unavailableApiToken, log)
}

return types.ActionContinue
}

Expand Down
154 changes: 154 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/failover.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package provider

import (
"errors"
"fmt"
"net/http"
"strings"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)

// failover holds the settings of the apiToken failover mechanism, as read by
// FromJson from the "failover" object of the provider config.
type failover struct {
// Whether the apiToken failover mechanism is enabled.
// @Title zh-CN 是否启用 apiToken 的 failover 机制
enabled bool `required:"true" yaml:"enabled" json:"enabled"`
// Failure count at which an apiToken is moved to the unavailable list.
// FromJson substitutes 3 when the value is 0.
// @Title zh-CN 触发 failover 的失败阈值
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
// Number of successful health checks after which an apiToken is added back.
// FromJson substitutes 1 when the value is 0.
// @Title zh-CN 健康检测的成功阈值
successThreshold int64 `required:"false" yaml:"successThreshold" json:"successThreshold"`
// Interval between health checks, in milliseconds.
// FromJson substitutes 5000 when the value is 0.
// @Title zh-CN 健康检测的间隔时间,单位毫秒
healthCheckInterval int64 `required:"false" yaml:"healthCheckInterval" json:"healthCheckInterval"`
// Timeout of a health check request, in milliseconds.
// FromJson substitutes 5000 when the value is 0.
// @Title zh-CN 健康检测的超时时间,单位毫秒
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
// Model name placed in the health check request body; Validate rejects an
// empty value.
// @Title zh-CN 健康检测使用的模型
healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
}

// Package-level failover state shared by every ProviderConfig method below.
var (
// ApiTokens is the list of tokens currently in rotation. SetApiTokensFailover
// seeds it from the provider config, and tokens are removed from or added
// back to it by HandleUnavailableApiToken and HandleAvailableApiToken.
ApiTokens []string
// ApiTokenRequestFailureCount counts failures per apiToken. The entry is
// reset to 0 when the token is moved to UnavailableApiTokens.
ApiTokenRequestFailureCount = make(map[string]int64)
// ApiTokenRequestSuccessCount counts successful health checks per apiToken.
// The entry is reset to 0 when the token is moved back to ApiTokens.
ApiTokenRequestSuccessCount = make(map[string]int64)
// healthCheckClient sends the health check requests; it is created in
// SetApiTokensFailover.
healthCheckClient wrapper.HttpClient
// UnavailableApiTokens lists the tokens taken out of rotation that are
// waiting for a successful health check.
UnavailableApiTokens []string
)

// Keys used with the per-request HttpContext (SetContext / GetContext).
const (
// ApiTokenInUse is the key under which the apiToken selected for the
// current request is stored.
ApiTokenInUse = "apiTokenInUse"
// ApiTokenHealthCheck is a context key.
// NOTE(review): presumably marks a request as a health check probe; confirm
// where it is set.
ApiTokenHealthCheck = "apiTokenHealthCheck"
)

// FromJson fills the failover settings from the given "failover" JSON object.
//
// Numeric options whose value reads as 0 are replaced by defaults:
// failureThreshold 3, successThreshold 1, healthCheckInterval 5000 and
// healthCheckTimeout 5000. enabled and healthCheckModel are taken as-is.
func (f *failover) FromJson(json gjson.Result) {
	// intOr reads the integer option stored under key and substitutes
	// fallback when the value is 0.
	intOr := func(key string, fallback int64) int64 {
		if value := json.Get(key).Int(); value != 0 {
			return value
		}
		return fallback
	}

	f.enabled = json.Get("enabled").Bool()
	f.failureThreshold = intOr("failureThreshold", 3)
	f.successThreshold = intOr("successThreshold", 1)
	f.healthCheckInterval = intOr("healthCheckInterval", 5000)
	f.healthCheckTimeout = intOr("healthCheckTimeout", 5000)
	f.healthCheckModel = json.Get("healthCheckModel").String()
}

// Validate checks the failover settings. It returns an error when
// healthCheckModel is empty and nil otherwise.
func (f *failover) Validate() error {
	if f.healthCheckModel != "" {
		return nil
	}
	return errors.New("missing healthCheckModel in failover config")
}

func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) {
ApiTokens = c.apiTokens
// TODO: 目前需要手动加一个 cluster 指向本地的地址,健康检测需要访问该地址
healthCheckClient = wrapper.NewClusterClient(wrapper.StaticIpCluster{
ServiceName: "local_cluster",
Port: 10000,
})
johnlanni marked this conversation as resolved.
Show resolved Hide resolved

if c.failover != nil && c.failover.enabled {
wrapper.RegisteTickFunc(c.failover.healthCheckTimeout, func() {
if len(UnavailableApiTokens) > 0 {
for _, apiToken := range UnavailableApiTokens {
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(UnavailableApiTokens, ", "))

path := "/v1/chat/completions"
headers := [][2]string{
{"Content-Type", "application/json"},
{"ApiToken-Health-Check", apiToken},
}
body := []byte(fmt.Sprintf(`{
"model": "%s",
"messages": [
{
"role": "user",
"content": "who are you?"
}
]
}`, c.failover.healthCheckModel))
err := healthCheckClient.Post(path, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode == 200 {
c.HandleAvailableApiToken(apiToken, log)
}
}, uint32(c.failover.healthCheckTimeout))
if err != nil {
log.Errorf("Failed to perform health check request: %v", err)
}
}
}
})
}
}

// HandleAvailableApiToken records one successful health check for apiToken.
// Once the token's success count reaches the configured successThreshold, the
// token is moved from UnavailableApiTokens back to ApiTokens and its success
// count is reset to 0.
func (c *ProviderConfig) HandleAvailableApiToken(apiToken string, log wrapper.Log) {
	successes := ApiTokenRequestSuccessCount[apiToken] + 1
	ApiTokenRequestSuccessCount[apiToken] = successes

	// Not enough successful probes yet: keep the token out of rotation.
	if successes < c.failover.successThreshold {
		return
	}

	log.Infof("apiToken %s is available now, add it back to the list", apiToken)
	c.RemoveToken(&UnavailableApiTokens, apiToken)
	c.AddToken(&ApiTokens, apiToken)
	ApiTokenRequestSuccessCount[apiToken] = 0
}

// HandleUnavailableApiToken records one failed request for apiToken. Once the
// token's failure count reaches the configured failureThreshold, the token is
// moved from ApiTokens to UnavailableApiTokens and its failure count is reset
// to 0.
//
// The call is a no-op when failover is not configured or not enabled: without
// a failover config there is no threshold to compare against, and without the
// periodic health check a removed token would never be added back.
func (c *ProviderConfig) HandleUnavailableApiToken(apiToken string, log wrapper.Log) {
	if c.failover == nil || !c.failover.enabled {
		return
	}

	ApiTokenRequestFailureCount[apiToken]++
	if ApiTokenRequestFailureCount[apiToken] < c.failover.failureThreshold {
		return
	}

	log.Errorf("Remove unavailable apiToken from list: %s", apiToken)
	c.RemoveToken(&ApiTokens, apiToken)
	c.AddToken(&UnavailableApiTokens, apiToken)
	ApiTokenRequestFailureCount[apiToken] = 0
}

// RemoveToken replaces *tokens with a newly allocated slice that holds every
// element of the old one except those equal to apiToken, in their original
// order. The slice previously referenced by *tokens is not modified.
func (c *ProviderConfig) RemoveToken(tokens *[]string, apiToken string) {
	current := *tokens
	kept := make([]string, 0)
	for i := range current {
		if current[i] == apiToken {
			continue
		}
		kept = append(kept, current[i])
	}
	*tokens = kept
}

// AddToken appends apiToken to *tokens unless the slice already holds an
// equal element, in which case *tokens is left untouched.
func (c *ProviderConfig) AddToken(tokens *[]string, apiToken string) {
	for _, existing := range *tokens {
		if existing == apiToken {
			return
		}
	}
	*tokens = append(*tokens, apiToken)
}

// contains reports whether element occurs in slice.
func contains(slice []string, element string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == element {
			return true
		}
	}
	return false
}
22 changes: 18 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ type ProviderConfig struct {
// @Title zh-CN 请求超时
// @Description zh-CN 请求AI服务的超时时间,单位为毫秒。默认值为120000,即2分钟
timeout uint32 `required:"false" yaml:"timeout" json:"timeout"`
// @Title zh-CN apiToken 故障切换
// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
failover *failover `required:"false" yaml:"failover" json:"failover"`
// @Title zh-CN 基于OpenAI协议的自定义后端URL
// @Description zh-CN 仅适用于支持 openai 协议的服务。
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
Expand Down Expand Up @@ -262,6 +265,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
}
}
}

failoverJson := json.Get("failover")
if failoverJson.Exists() {
c.failover = &failover{}
c.failover.FromJson(failoverJson)
}
}

func (c *ProviderConfig) Validate() error {
Expand All @@ -277,6 +286,12 @@ func (c *ProviderConfig) Validate() error {
}
}

if c.failover != nil {
if err := c.failover.Validate(); err != nil {
return err
}
}

if c.typ == "" {
return errors.New("missing type in provider config")
}
Expand All @@ -300,15 +315,14 @@ func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) strin
}

// GetRandomToken returns one token from the shared ApiTokens list: the empty
// string when the list is empty, the only element when it holds exactly one,
// and otherwise an element chosen at random.
//
// The scraped hunk interleaved the pre-change lines (reading c.apiTokens)
// with the post-change lines, which declared count twice and did not compile.
// This keeps only the post-change version, which reads ApiTokens so that
// tokens moved out of rotation by the failover logic are no longer picked.
func (c *ProviderConfig) GetRandomToken() string {
	count := len(ApiTokens)
	switch count {
	case 0:
		return ""
	case 1:
		return ApiTokens[0]
	default:
		return ApiTokens[rand.Intn(count)]
	}
}

Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(qwenDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string))

if m.config.protocol == protocolOriginal {
return types.ActionContinue, nil
Expand Down
Loading