From 095b25edee108e537808da526783934e18c6ce60 Mon Sep 17 00:00:00 2001 From: Se7en Date: Tue, 27 Aug 2024 11:18:36 +0800 Subject: [PATCH 01/31] feat: implement apiToken failover mechanism --- .../extensions/ai-proxy/config/config.go | 3 +- plugins/wasm-go/extensions/ai-proxy/main.go | 26 ++- .../extensions/ai-proxy/provider/failover.go | 154 ++++++++++++++++++ .../extensions/ai-proxy/provider/provider.go | 22 ++- .../extensions/ai-proxy/provider/qwen.go | 2 +- 5 files changed, 197 insertions(+), 10 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/failover.go diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go index e1bba64027..1a9f9ab62c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/config/config.go +++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go @@ -1,9 +1,8 @@ package config import ( - "github.com/tidwall/gjson" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" + "github.com/tidwall/gjson" ) // @Name ai-proxy diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 56e9f58089..1c638e9c05 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -38,8 +38,6 @@ func main() { } func parseConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error { - // log.Debugf("loading config: %s", json.String()) - pluginConfig.FromJson(json) if err := pluginConfig.Validate(); err != nil { return err @@ -47,6 +45,10 @@ func parseConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapp if err := pluginConfig.Complete(); err != nil { return err } + + providerConfig := pluginConfig.GetProviderConfig() + providerConfig.SetApiTokensFailover(log) + return nil } @@ -72,8 +74,18 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf ctx.SetContext(ctxKeyApiName, apiName) if handler, 
ok := activeProvider.(provider.RequestHeadersHandler); ok { - // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. + // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. ctx.DisableReroute() + + log.Debugf("ApiTokens: %s, UnavailableApiTokens: %s", provider.ApiTokens, provider.UnavailableApiTokens) + apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader("ApiToken-Health-Check") + if apiTokenHealthCheck != "" { + ctx.SetContext(provider.ApiTokenInUse, apiTokenHealthCheck) + } else { + providerConfig := pluginConfig.GetProviderConfig() + ctx.SetContext(provider.ApiTokenInUse, providerConfig.GetRandomToken()) + } + hasRequestBody := wrapper.HasRequestBody() action, err := handler.OnRequestHeaders(ctx, apiName, log) if err == nil { @@ -85,6 +97,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf } return action } + _ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err)) return types.ActionContinue } @@ -145,6 +158,13 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo log.Errorf("unable to load :status header from response: %v", err) } ctx.DontReadResponseBody() + + if ctx.GetContext(provider.ApiTokenHealthCheck) == nil { + unavailableApiToken := ctx.GetContext(provider.ApiTokenInUse).(string) + providerConfig := pluginConfig.GetProviderConfig() + providerConfig.HandleUnavailableApiToken(unavailableApiToken, log) + } + return types.ActionContinue } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go new file mode 100644 index 0000000000..3598dfb412 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -0,0 +1,154 @@ +package provider + +import ( + "errors" + "fmt" + "net/http" + "strings" + + 
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +type failover struct { + // @Title zh-CN 是否启用 apiToken 的 failover 机制 + enabled bool `required:"true" yaml:"enabled" json:"enabled"` + // @Title zh-CN 触发 failover 的失败阈值 + failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"` + // @Title zh-CN 健康检测的成功阈值 + successThreshold int64 `required:"false" yaml:"successThreshold" json:"successThreshold"` + // @Title zh-CN 健康检测的间隔时间,单位毫秒 + healthCheckInterval int64 `required:"false" yaml:"healthCheckInterval" json:"healthCheckInterval"` + // @Title zh-CN 健康检测的超时时间,单位毫秒 + healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"` + // @Title zh-CN 健康检测使用的模型 + healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` +} + +var ( + ApiTokens []string + ApiTokenRequestFailureCount = make(map[string]int64) + ApiTokenRequestSuccessCount = make(map[string]int64) + healthCheckClient wrapper.HttpClient + UnavailableApiTokens []string +) + +const ( + ApiTokenInUse = "apiTokenInUse" + ApiTokenHealthCheck = "apiTokenHealthCheck" +) + +func (f *failover) FromJson(json gjson.Result) { + f.enabled = json.Get("enabled").Bool() + f.failureThreshold = json.Get("failureThreshold").Int() + if f.failureThreshold == 0 { + f.failureThreshold = 3 + } + f.successThreshold = json.Get("successThreshold").Int() + if f.successThreshold == 0 { + f.successThreshold = 1 + } + f.healthCheckInterval = json.Get("healthCheckInterval").Int() + if f.healthCheckInterval == 0 { + f.healthCheckInterval = 5000 + } + f.healthCheckTimeout = json.Get("healthCheckTimeout").Int() + if f.healthCheckTimeout == 0 { + f.healthCheckTimeout = 5000 + } + f.healthCheckModel = json.Get("healthCheckModel").String() +} + +func (f *failover) Validate() error { + if f.healthCheckModel == "" { + return errors.New("missing healthCheckModel in failover config") + } + return nil +} + +func (c 
*ProviderConfig) SetApiTokensFailover(log wrapper.Log) { + ApiTokens = c.apiTokens + // TODO: 目前需要手动加一个 cluster 指向本地的地址,健康检测需要访问该地址 + healthCheckClient = wrapper.NewClusterClient(wrapper.StaticIpCluster{ + ServiceName: "local_cluster", + Port: 10000, + }) + + if c.failover != nil && c.failover.enabled { + wrapper.RegisteTickFunc(c.failover.healthCheckTimeout, func() { + if len(UnavailableApiTokens) > 0 { + for _, apiToken := range UnavailableApiTokens { + log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(UnavailableApiTokens, ", ")) + + path := "/v1/chat/completions" + headers := [][2]string{ + {"Content-Type", "application/json"}, + {"ApiToken-Health-Check", apiToken}, + } + body := []byte(fmt.Sprintf(`{ + "model": "%s", + "messages": [ + { + "role": "user", + "content": "who are you?" + } + ] + }`, c.failover.healthCheckModel)) + err := healthCheckClient.Post(path, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode == 200 { + c.HandleAvailableApiToken(apiToken, log) + } + }, uint32(c.failover.healthCheckTimeout)) + if err != nil { + log.Errorf("Failed to perform health check request: %v", err) + } + } + } + }) + } +} + +func (c *ProviderConfig) HandleAvailableApiToken(apiToken string, log wrapper.Log) { + ApiTokenRequestSuccessCount[apiToken]++ + if ApiTokenRequestSuccessCount[apiToken] >= c.failover.successThreshold { + log.Infof("apiToken %s is available now, add it back to the list", apiToken) + c.RemoveToken(&UnavailableApiTokens, apiToken) + c.AddToken(&ApiTokens, apiToken) + ApiTokenRequestSuccessCount[apiToken] = 0 + } +} + +func (c *ProviderConfig) HandleUnavailableApiToken(apiToken string, log wrapper.Log) { + ApiTokenRequestFailureCount[apiToken]++ + if ApiTokenRequestFailureCount[apiToken] >= c.failover.failureThreshold { + log.Errorf("Remove unavailable apiToken from list: %s", apiToken) + c.RemoveToken(&ApiTokens, apiToken) + c.AddToken(&UnavailableApiTokens, apiToken) + 
ApiTokenRequestFailureCount[apiToken] = 0 + } +} + +func (c *ProviderConfig) RemoveToken(tokens *[]string, apiToken string) { + tmp := make([]string, 0) + for _, v := range *tokens { + if v != apiToken { + tmp = append(tmp, v) + } + } + *tokens = tmp +} + +func (c *ProviderConfig) AddToken(tokens *[]string, apiToken string) { + if !contains(*tokens, apiToken) { + *tokens = append(*tokens, apiToken) + } +} + +func contains(slice []string, element string) bool { + for _, v := range slice { + if v == element { + return true + } + } + return false +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index b3f29feda5..e654bcb487 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -129,6 +129,9 @@ type ProviderConfig struct { // @Title zh-CN 请求超时 // @Description zh-CN 请求AI服务的超时时间,单位为毫秒。默认值为120000,即2分钟 timeout uint32 `required:"false" yaml:"timeout" json:"timeout"` + // @Title zh-CN apiToken 故障切换 + // @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表 + failover *failover `required:"false" yaml:"failover" json:"failover"` // @Title zh-CN 基于OpenAI协议的自定义后端URL // @Description zh-CN 仅适用于支持 openai 协议的服务。 openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"` @@ -262,6 +265,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } } } + + failoverJson := json.Get("failover") + if failoverJson.Exists() { + c.failover = &failover{} + c.failover.FromJson(failoverJson) + } } func (c *ProviderConfig) Validate() error { @@ -277,6 +286,12 @@ func (c *ProviderConfig) Validate() error { } } + if c.failover != nil { + if err := c.failover.Validate(); err != nil { + return err + } + } + if c.typ == "" { return errors.New("missing type in provider config") } @@ -300,15 +315,14 @@ func (c *ProviderConfig) GetOrSetTokenWithContext(ctx 
wrapper.HttpContext) strin } func (c *ProviderConfig) GetRandomToken() string { - apiTokens := c.apiTokens - count := len(apiTokens) + count := len(ApiTokens) switch count { case 0: return "" case 1: - return apiTokens[0] + return ApiTokens[0] default: - return apiTokens[rand.Intn(count)] + return ApiTokens[rand.Intn(count)] } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index ffa5be2e37..62c07c1213 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -78,7 +78,7 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestHost(qwenDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) if m.config.protocol == protocolOriginal { return types.ActionContinue, nil From 4af200cfbdf1d39d2d09c93067b39d9f21a3ef82 Mon Sep 17 00:00:00 2001 From: Se7en Date: Sat, 31 Aug 2024 23:23:13 +0800 Subject: [PATCH 02/31] Use SetSharedData for leader election and syncing apiTokens between Wasm VMs --- plugins/wasm-go/extensions/ai-proxy/main.go | 24 +- .../extensions/ai-proxy/provider/failover.go | 323 +++++++++++++++--- .../extensions/ai-proxy/provider/provider.go | 7 +- 3 files changed, 297 insertions(+), 57 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 1c638e9c05..b8fd8c70e2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -77,14 +77,19 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. 
ctx.DisableReroute() - log.Debugf("ApiTokens: %s, UnavailableApiTokens: %s", provider.ApiTokens, provider.UnavailableApiTokens) - apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader("ApiToken-Health-Check") - if apiTokenHealthCheck != "" { - ctx.SetContext(provider.ApiTokenInUse, apiTokenHealthCheck) - } else { - providerConfig := pluginConfig.GetProviderConfig() - ctx.SetContext(provider.ApiTokenInUse, providerConfig.GetRandomToken()) + providerConfig := pluginConfig.GetProviderConfig() + apiTokenInUse := providerConfig.GetRandomToken() + if providerConfig.IsFailoverEnabled() { + // Use the health check token if it is a health check request. + if apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader("ApiToken-Health-Check"); apiTokenHealthCheck != "" { + apiTokenInUse = apiTokenHealthCheck + } else { + // if enable apiToken failover, only use available apiToken + apiTokenInUse = providerConfig.GetGlobalRandomToken(log) + } } + log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiTokenInUse) + ctx.SetContext(provider.ApiTokenInUse, apiTokenInUse) hasRequestBody := wrapper.HasRequestBody() action, err := handler.OnRequestHeaders(ctx, apiName, log) @@ -159,9 +164,10 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo } ctx.DontReadResponseBody() - if ctx.GetContext(provider.ApiTokenHealthCheck) == nil { + providerConfig := pluginConfig.GetProviderConfig() + // If apiToken failover is enabled and the request is not a health check request, handle unavailable apiToken. 
+ if providerConfig.IsFailoverEnabled() && ctx.GetContext(provider.ApiTokenHealthCheck) == nil { unavailableApiToken := ctx.GetContext(provider.ApiTokenInUse).(string) - providerConfig := pluginConfig.GetProviderConfig() providerConfig.HandleUnavailableApiToken(unavailableApiToken, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 3598dfb412..4dc0dfdf36 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -1,12 +1,18 @@ package provider import ( + "encoding/binary" "errors" "fmt" + "math/rand" "net/http" + "strconv" "strings" + "time" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" ) @@ -25,17 +31,27 @@ type failover struct { healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` } +type lease struct { + vmID string + timestamp int64 +} + var ( - ApiTokens []string - ApiTokenRequestFailureCount = make(map[string]int64) - ApiTokenRequestSuccessCount = make(map[string]int64) - healthCheckClient wrapper.HttpClient - UnavailableApiTokens []string + healthCheckClient wrapper.HttpClient ) const ( ApiTokenInUse = "apiTokenInUse" ApiTokenHealthCheck = "apiTokenHealthCheck" + vmLease = "vmLease" + // The length of vmID generated by generateVMID is fixed to 16 bytes + vmIDLength = 16 + // The timestamp is 8 bytes (int64) + leaseLength = vmIDLength + 8 + ctxApiTokenRequestFailureCount = "apiTokenRequestFailureCount" + ctxApiTokenRequestSuccessCount = "apiTokenRequestSuccessCount" + ctxApiTokens = "apiTokens" + ctxUnavailableApiTokens = "unavailableApiTokens" ) func (f *failover) FromJson(json gjson.Result) { @@ -67,25 +83,38 @@ func (f *failover) Validate() error { } func (c *ProviderConfig) SetApiTokensFailover(log 
wrapper.Log) { - ApiTokens = c.apiTokens // TODO: 目前需要手动加一个 cluster 指向本地的地址,健康检测需要访问该地址 healthCheckClient = wrapper.NewClusterClient(wrapper.StaticIpCluster{ ServiceName: "local_cluster", Port: 10000, }) + vmID := generateVMID() + err := c.initApiTokens() + if err != nil { + log.Errorf("Failed to init apiTokens: %v", err) + } + if c.failover != nil && c.failover.enabled { wrapper.RegisteTickFunc(c.failover.healthCheckTimeout, func() { - if len(UnavailableApiTokens) > 0 { - for _, apiToken := range UnavailableApiTokens { - log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(UnavailableApiTokens, ", ")) - - path := "/v1/chat/completions" - headers := [][2]string{ - {"Content-Type", "application/json"}, - {"ApiToken-Health-Check", apiToken}, - } - body := []byte(fmt.Sprintf(`{ + // Only the Wasm VM that successfully acquires the lease will perform health check + if tryAcquireOrRenewLease(vmID, log) { + log.Debugf("Successfully acquired or renewed lease: %s", vmID) + unavailableTokens, _, err := getApiTokens(ctxUnavailableApiTokens) + if err != nil { + log.Errorf("Failed to get unavailable tokens: %v", err) + return + } + if len(unavailableTokens) > 0 { + for _, apiToken := range unavailableTokens { + log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", ")) + + path := "/v1/chat/completions" + headers := [][2]string{ + {"Content-Type", "application/json"}, + {"ApiToken-Health-Check", apiToken}, + } + body := []byte(fmt.Sprintf(`{ "model": "%s", "messages": [ { @@ -94,13 +123,14 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) { } ] }`, c.failover.healthCheckModel)) - err := healthCheckClient.Post(path, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - if statusCode == 200 { - c.HandleAvailableApiToken(apiToken, log) + err := healthCheckClient.Post(path, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + 
if statusCode == 200 { + c.HandleAvailableApiToken(apiToken, log) + } + }, uint32(c.failover.healthCheckTimeout)) + if err != nil { + log.Errorf("Failed to perform health check request: %v", err) } - }, uint32(c.failover.healthCheckTimeout)) - if err != nil { - log.Errorf("Failed to perform health check request: %v", err) } } } @@ -108,47 +138,250 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) { } } +func tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool { + data, cas, err := proxywasm.GetSharedData(vmLease) + if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { + log.Errorf("Failed to get lease: %v", err) + return false + } + + now := time.Now().Unix() + if data == nil { + return setLease(vmID, now, cas, log) + } + + leaseData := leaseFromBytes(data) + // If vmID is itself, try to renew the lease directly + // If the lease is expired, try to acquire the lease + if leaseData.vmID == vmID || now-leaseData.timestamp > 60 { + leaseData.vmID = vmID + leaseData.timestamp = now + return setLease(vmID, now, cas, log) + } + + return false +} + +func setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool { + leaseData := lease{ + vmID: vmID, + timestamp: timestamp, + } + if err := proxywasm.SetSharedData(vmLease, leaseData.toBytes(), cas); err != nil { + log.Errorf("Failed to set or renew lease: %v", err) + return false + } + return true +} + +func (l *lease) toBytes() []byte { + b := make([]byte, leaseLength) + copy(b[:vmIDLength], l.vmID) + binary.LittleEndian.PutUint64(b[vmIDLength:], uint64(l.timestamp)) + + return b +} + +func leaseFromBytes(b []byte) *lease { + if len(b) != leaseLength { + return nil + } + + return &lease{ + vmID: string(b[:vmIDLength]), + timestamp: int64(binary.LittleEndian.Uint64(b[vmIDLength:])), + } +} + +func generateVMID() string { + return fmt.Sprintf("%016x", time.Now().Nanosecond()) +} + +// When number of request successes exceeds the threshold during health check, +// add the apiToken 
back to the available list and remove it from the unavailable list func (c *ProviderConfig) HandleAvailableApiToken(apiToken string, log wrapper.Log) { - ApiTokenRequestSuccessCount[apiToken]++ - if ApiTokenRequestSuccessCount[apiToken] >= c.failover.successThreshold { + successCount, successCountCas, err := getApiTokenRequestCount(ctxApiTokenRequestSuccessCount) + if err != nil { + log.Errorf("Failed to get apiToken health check success count: %v", err) + return + } + + successCount[apiToken]++ + + if successCount[apiToken] >= c.failover.successThreshold { + unavailableTokens, unavailableTokensCas, err := getApiTokens(ctxUnavailableApiTokens) + if err != nil { + log.Errorf("Failed to get unavailable apiToken: %v", err) + return + } log.Infof("apiToken %s is available now, add it back to the list", apiToken) - c.RemoveToken(&UnavailableApiTokens, apiToken) - c.AddToken(&ApiTokens, apiToken) - ApiTokenRequestSuccessCount[apiToken] = 0 + c.removeApiToken(ctxUnavailableApiTokens, apiToken, unavailableTokens, unavailableTokensCas, log) + + availableTokens, availableCas, err := getApiTokens(ctxApiTokens) + if err != nil { + log.Errorf("Failed to get available apiToken: %v", err) + return + } + c.addApiToken(ctxApiTokens, apiToken, availableTokens, availableCas, log) + c.resetApiTokenRequestCounter(ctxApiTokenRequestSuccessCount, apiToken, successCount, successCountCas, log) + } else { + setApiTokenRequestCount(ctxApiTokenRequestSuccessCount, successCount, successCountCas, log) } } +// When number of request failures exceeds the threshold, +// remove the apiToken from the available list and add it to the unavailable list func (c *ProviderConfig) HandleUnavailableApiToken(apiToken string, log wrapper.Log) { - ApiTokenRequestFailureCount[apiToken]++ - if ApiTokenRequestFailureCount[apiToken] >= c.failover.failureThreshold { - log.Errorf("Remove unavailable apiToken from list: %s", apiToken) - c.RemoveToken(&ApiTokens, apiToken) - c.AddToken(&UnavailableApiTokens, apiToken) 
- ApiTokenRequestFailureCount[apiToken] = 0 + failureCount, failureCountCas, err := getApiTokenRequestCount(ctxApiTokenRequestFailureCount) + if err != nil { + log.Errorf("Failed to get apiToken request failure count: %v", err) + return + } + + availableTokens, availableTokensCas, err := getApiTokens(ctxApiTokens) + if err != nil { + log.Errorf("Failed to get available apiToken: %v", err) + return + } + // unavailable apiToken has been removed from the available list + if !containsElement(availableTokens, apiToken) { + return + } + + failureCount[apiToken]++ + if failureCount[apiToken] >= c.failover.failureThreshold { + log.Infof("Remove unavailable apiToken from list: %s", apiToken) + c.removeApiToken(ctxApiTokens, apiToken, availableTokens, availableTokensCas, log) + + unavailableTokens, unavailableCas, err := getApiTokens(ctxUnavailableApiTokens) + if err != nil { + log.Errorf("Failed to get unavailable apiToken: %v", err) + return + } + c.addApiToken(ctxUnavailableApiTokens, apiToken, unavailableTokens, unavailableCas, log) + c.resetApiTokenRequestCounter(ctxApiTokenRequestFailureCount, apiToken, failureCount, failureCountCas, log) + } else { + setApiTokenRequestCount(ctxApiTokenRequestFailureCount, failureCount, failureCountCas, log) } } -func (c *ProviderConfig) RemoveToken(tokens *[]string, apiToken string) { - tmp := make([]string, 0) - for _, v := range *tokens { - if v != apiToken { - tmp = append(tmp, v) +func (c *ProviderConfig) removeApiToken(key string, apiToken string, tokens []string, cas uint32, log wrapper.Log) { + err := setApiTokens(key, removeElement(tokens, apiToken), cas) + if err != nil { + log.Errorf("Failed to remove %s from %s, err: %v", apiToken, key, err) + } +} + +func (c *ProviderConfig) resetApiTokenRequestCounter(key string, apiToken string, tokens map[string]int64, cas uint32, log wrapper.Log) { + delete(tokens, apiToken) + setApiTokenRequestCount(key, tokens, cas, log) +} + +func (c *ProviderConfig) addApiToken(key string, 
apiToken string, tokens []string, cas uint32, log wrapper.Log) { + if !containsElement(tokens, apiToken) { + tokens = append(tokens, apiToken) + err := setApiTokens(key, tokens, cas) + if err != nil { + log.Errorf("Failed to add %s to %s, err: %v", apiToken, key, err) } } - *tokens = tmp } -func (c *ProviderConfig) AddToken(tokens *[]string, apiToken string) { - if !contains(*tokens, apiToken) { - *tokens = append(*tokens, apiToken) +func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) { + data, cas, err := proxywasm.GetSharedData(key) + if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { + return nil, 0, err } + tokens := make(map[string]int64) + if len(data) > 0 { + pairs := strings.Split(string(data), "\n") + for _, pair := range pairs { + kv := strings.Split(pair, "\x00") + if len(kv) == 2 { + value, _ := strconv.ParseInt(kv[1], 10, 64) + tokens[kv[0]] = value + } + } + } + return tokens, cas, nil } -func contains(slice []string, element string) bool { - for _, v := range slice { - if v == element { +func setApiTokenRequestCount(key string, tokens map[string]int64, cas uint32, log wrapper.Log) { + var pairs []string + for k, v := range tokens { + // use a special character "\x00" to separate key (token) and value (failure or success count), + // in order to retrieve tokens from byte in getApiTokenRequestCount + pair := fmt.Sprintf("%s\x00%d", k, v) + pairs = append(pairs, pair) + } + // use a special character "\n" to separate tokens + data := strings.Join(pairs, "\n") + err := proxywasm.SetSharedData(key, []byte(data), cas) + if err != nil { + log.Errorf("Failed to update %s, err: %v", key, err) + } +} + +func getApiTokens(key string) ([]string, uint32, error) { + data, cas, err := proxywasm.GetSharedData(key) + if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { + return nil, 0, err + } + var tokens []string + if len(data) > 0 { + tokens = strings.Split(string(data), "\n") + } + return tokens, cas, nil +} + +func 
setApiTokens(key string, tokens []string, cas uint32) error { + // use a special character "\n" to separate tokens, in order to retrieve tokens from byte in getApiTokens + data := strings.Join(tokens, "\n") + return proxywasm.SetSharedData(key, []byte(data), cas) +} + +func removeElement(slice []string, s string) []string { + for i := 0; i < len(slice); i++ { + if slice[i] == s { + slice = append(slice[:i], slice[i+1:]...) + i-- + } + } + return slice +} + +func containsElement(slice []string, s string) bool { + for _, item := range slice { + if item == s { return true } } return false } + +func (c *ProviderConfig) initApiTokens() error { + return setApiTokens(ctxApiTokens, c.apiTokens, 0) +} + +func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string { + apiTokens, _, err := getApiTokens(ctxApiTokens) + unavailableApiTokens, _, err := getApiTokens(ctxUnavailableApiTokens) + log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens) + + if err != nil { + return "" + } + count := len(apiTokens) + switch count { + case 0: + return "" + case 1: + return apiTokens[0] + default: + return apiTokens[rand.Intn(count)] + } +} + +func (c *ProviderConfig) IsFailoverEnabled() bool { + return c.failover != nil && c.failover.enabled +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index e654bcb487..f672de6350 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -315,14 +315,15 @@ func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) strin } func (c *ProviderConfig) GetRandomToken() string { - count := len(ApiTokens) + apiTokens := c.apiTokens + count := len(apiTokens) switch count { case 0: return "" case 1: - return ApiTokens[0] + return apiTokens[0] default: - return ApiTokens[rand.Intn(count)] + return apiTokens[rand.Intn(count)] } } From 
856343cfa66bccff25662eabf402ead035a9dbf8 Mon Sep 17 00:00:00 2001 From: Se7en Date: Sun, 1 Sep 2024 09:38:48 +0800 Subject: [PATCH 03/31] support failover for all models --- plugins/wasm-go/extensions/ai-proxy/provider/ai360.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/baidu.go | 8 ++++---- plugins/wasm-go/extensions/ai-proxy/provider/claude.go | 2 +- .../wasm-go/extensions/ai-proxy/provider/cloudflare.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/deepl.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/gemini.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/groq.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/minimax.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/mistral.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/openai.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/spark.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/yi.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go | 2 +- 17 files changed, 20 insertions(+), 20 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 00443fcf5e..ec116125b5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -49,7 +49,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam _ = util.OverwriteRequestHost(ai360Domain) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - _ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken()) + _ = proxywasm.ReplaceHttpRequestHeader("Authorization", 
ctx.GetContext(ApiTokenInUse).(string)) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index c16a8e4395..2a01f079f2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -49,7 +49,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } _ = util.OverwriteRequestPath(baichuanChatCompletionPath) _ = util.OverwriteRequestHost(baichuanDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index fc779d5306..a5b39c8d2c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -83,7 +83,7 @@ func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, return types.ActionContinue, errors.New("request model is empty") } // 根据模型重写requestPath - path := b.getRequestPath(request.Model) + path := b.getRequestPath(ctx, request.Model) _ = util.OverwriteRequestPath(path) if b.config.context == nil { @@ -126,7 +126,7 @@ func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } request.Model = mappedModel ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - path := b.getRequestPath(mappedModel) + path := b.getRequestPath(ctx, mappedModel) _ = util.OverwriteRequestPath(path) if b.config.context == nil { @@ -226,13 +226,13 @@ type baiduTextGenRequest struct { UserId string `json:"user_id,omitempty"` } -func (b *baiduProvider) 
getRequestPath(baiduModel string) string { +func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string { // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t suffix, ok := baiduModelToPathSuffixMap[baiduModel] if !ok { suffix = baiduModel } - return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken()) + return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, ctx.GetContext(ApiTokenInUse).(string)) } func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 7bbbc93d79..3fdee70083 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -108,7 +108,7 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa _ = util.OverwriteRequestPath(claudeChatCompletionPath) _ = util.OverwriteRequestHost(claudeDomain) - _ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetRandomToken()) + _ = proxywasm.ReplaceHttpRequestHeader("x-api-key", ctx.GetContext(ApiTokenInUse).(string)) if c.config.claudeVersion == "" { c.config.claudeVersion = defaultVersion diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 35f6f2dc78..720349c835 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -49,7 +49,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A } _ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1)) _ = util.OverwriteRequestHost(cloudflareDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + 
c.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 924746c8c9..dc40c621d4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -79,7 +79,7 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestPath(deeplChatCompletionPath) - _ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") return types.HeaderStopIteration, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index 8cb71462d2..99e778fb3f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -49,7 +49,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } _ = util.OverwriteRequestPath(deepseekChatCompletionPath) _ = util.OverwriteRequestHost(deepseekDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 0d418c16a5..65596b4d2b 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -52,7 +52,7 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa return types.ActionContinue, errUnsupportedApiName } - _ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, g.config.GetRandomToken()) + _ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, ctx.GetContext(ApiTokenInUse).(string)) _ = util.OverwriteRequestHost(geminiDomain) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index 644e450ee9..b41427a196 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -47,7 +47,7 @@ func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName } _ = util.OverwriteRequestPath(groqChatCompletionPath) _ = util.OverwriteRequestHost(groqDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index ded72d7b51..ceec72d8f3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -79,7 +79,7 @@ func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestHost(minimaxDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = 
proxywasm.RemoveHttpRequestHeader("Content-Length") // Delay the header processing to allow changing streaming mode in OnRequestBody diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index b217d8019e..7000425366 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -44,7 +44,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestHost(mistralDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 6023b4abe8..8dfb56f4e5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -60,7 +60,7 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } _ = util.OverwriteRequestPath(moonshotChatCompletionPath) _ = util.OverwriteRequestHost(moonshotDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 9f34932c1a..2ff8ea69a5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -74,7 +74,7 @@ func (m *openaiProvider) OnRequestHeaders(ctx 
wrapper.HttpContext, apiName ApiNa _ = util.OverwriteRequestHost(m.customDomain) } if len(m.config.apiTokens) > 0 { - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) } _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index fc266dfbaa..bc6a4fd328 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -73,7 +73,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } _ = util.OverwriteRequestHost(sparkHost) _ = util.OverwriteRequestPath(sparkChatCompletionPath) - _ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index dd6792ed65..b593cb9343 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -47,7 +47,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } _ = util.OverwriteRequestPath(stepfunChatCompletionPath) _ = util.OverwriteRequestHost(stepfunDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go 
b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 287945d903..a0b6533ad6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -47,7 +47,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, } _ = util.OverwriteRequestPath(yiChatCompletionPath) _ = util.OverwriteRequestHost(yiDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 9640cd02f4..6178d862c1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -46,7 +46,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } _ = util.OverwriteRequestPath(zhipuAiChatCompletionPath) _ = util.OverwriteRequestHost(zhipuAiDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } From 7d5f427982c068d60e4807788368abf703430500 Mon Sep 17 00:00:00 2001 From: Se7en Date: Sat, 7 Sep 2024 13:31:40 +0800 Subject: [PATCH 04/31] add cas retry logic --- .../extensions/ai-proxy/provider/failover.go | 223 ++++++++++-------- 1 file changed, 131 insertions(+), 92 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 4dc0dfdf36..5eb9be274e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ 
-47,11 +47,16 @@ const ( // The length of vmID generated by generateVMID is fixed to 16 bytes vmIDLength = 16 // The timestamp is 8 bytes (int64) - leaseLength = vmIDLength + 8 - ctxApiTokenRequestFailureCount = "apiTokenRequestFailureCount" - ctxApiTokenRequestSuccessCount = "apiTokenRequestSuccessCount" - ctxApiTokens = "apiTokens" - ctxUnavailableApiTokens = "unavailableApiTokens" + leaseLength = vmIDLength + 8 + ctxApiTokenRequestFailureCount = "apiTokenRequestFailureCount" + ctxApiTokenRequestSuccessCount = "apiTokenRequestSuccessCount" + ctxApiTokens = "apiTokens" + ctxUnavailableApiTokens = "unavailableApiTokens" + casMaxRetries = 10 + addApiTokenOperation = "addApiToken" + removeApiTokenOperation = "removeApiToken" + addApiTokenRequestCountOperation = "addApiTokenRequestCount" + resetApiTokenRequestCountOperation = "resetApiTokenRequestCount" ) func (f *failover) FromJson(json gjson.Result) { @@ -167,14 +172,14 @@ func setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool { vmID: vmID, timestamp: timestamp, } - if err := proxywasm.SetSharedData(vmLease, leaseData.toBytes(), cas); err != nil { + if err := proxywasm.SetSharedData(vmLease, leaseData.leaseToBytes(), cas); err != nil { log.Errorf("Failed to set or renew lease: %v", err) return false } return true } -func (l *lease) toBytes() []byte { +func (l *lease) leaseToBytes() []byte { b := make([]byte, leaseLength) copy(b[:vmIDLength], l.vmID) binary.LittleEndian.PutUint64(b[vmIDLength:], uint64(l.timestamp)) @@ -200,45 +205,34 @@ func generateVMID() string { // When number of request successes exceeds the threshold during health check, // add the apiToken back to the available list and remove it from the unavailable list func (c *ProviderConfig) HandleAvailableApiToken(apiToken string, log wrapper.Log) { - successCount, successCountCas, err := getApiTokenRequestCount(ctxApiTokenRequestSuccessCount) + successApiTokenRequestCount, _, err := 
getApiTokenRequestCount(ctxApiTokenRequestSuccessCount) if err != nil { - log.Errorf("Failed to get apiToken health check success count: %v", err) + log.Errorf("Failed to get successApiTokenRequestCount: %v", err) return } - successCount[apiToken]++ - - if successCount[apiToken] >= c.failover.successThreshold { - unavailableTokens, unavailableTokensCas, err := getApiTokens(ctxUnavailableApiTokens) - if err != nil { - log.Errorf("Failed to get unavailable apiToken: %v", err) - return - } - log.Infof("apiToken %s is available now, add it back to the list", apiToken) - c.removeApiToken(ctxUnavailableApiTokens, apiToken, unavailableTokens, unavailableTokensCas, log) - - availableTokens, availableCas, err := getApiTokens(ctxApiTokens) - if err != nil { - log.Errorf("Failed to get available apiToken: %v", err) - return - } - c.addApiToken(ctxApiTokens, apiToken, availableTokens, availableCas, log) - c.resetApiTokenRequestCounter(ctxApiTokenRequestSuccessCount, apiToken, successCount, successCountCas, log) + successCount := successApiTokenRequestCount[apiToken] + 1 + if successCount >= c.failover.successThreshold { + log.Infof("apiToken %s is available now, add it back to the apiTokens list", apiToken) + removeApiToken(ctxUnavailableApiTokens, apiToken, log) + addApiToken(ctxApiTokens, apiToken, log) + resetApiTokenRequestCount(ctxApiTokenRequestSuccessCount, apiToken, log) } else { - setApiTokenRequestCount(ctxApiTokenRequestSuccessCount, successCount, successCountCas, log) + log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check......", apiToken, successCount) + addApiTokenRequestCount(ctxApiTokenRequestSuccessCount, apiToken, log) } } // When number of request failures exceeds the threshold, // remove the apiToken from the available list and add it to the unavailable list func (c *ProviderConfig) HandleUnavailableApiToken(apiToken string, log wrapper.Log) { - failureCount, failureCountCas, err := 
getApiTokenRequestCount(ctxApiTokenRequestFailureCount) + failureApiTokenRequestCount, _, err := getApiTokenRequestCount(ctxApiTokenRequestFailureCount) if err != nil { - log.Errorf("Failed to get apiToken request failure count: %v", err) + log.Errorf("Failed to get failureApiTokenRequestCount: %v", err) return } - availableTokens, availableTokensCas, err := getApiTokens(ctxApiTokens) + availableTokens, _, err := getApiTokens(ctxApiTokens) if err != nil { log.Errorf("Failed to get available apiToken: %v", err) return @@ -248,77 +242,58 @@ func (c *ProviderConfig) HandleUnavailableApiToken(apiToken string, log wrapper. return } - failureCount[apiToken]++ - if failureCount[apiToken] >= c.failover.failureThreshold { - log.Infof("Remove unavailable apiToken from list: %s", apiToken) - c.removeApiToken(ctxApiTokens, apiToken, availableTokens, availableTokensCas, log) - - unavailableTokens, unavailableCas, err := getApiTokens(ctxUnavailableApiTokens) - if err != nil { - log.Errorf("Failed to get unavailable apiToken: %v", err) - return - } - c.addApiToken(ctxUnavailableApiTokens, apiToken, unavailableTokens, unavailableCas, log) - c.resetApiTokenRequestCounter(ctxApiTokenRequestFailureCount, apiToken, failureCount, failureCountCas, log) + failureCount := failureApiTokenRequestCount[apiToken] + 1 + if failureCount >= c.failover.failureThreshold { + log.Infof("apiToken %s is unavailable now, remove it from apiTokens list", apiToken) + removeApiToken(ctxApiTokens, apiToken, log) + addApiToken(ctxUnavailableApiTokens, apiToken, log) + resetApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiToken, log) } else { - setApiTokenRequestCount(ctxApiTokenRequestFailureCount, failureCount, failureCountCas, log) + log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount) + addApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiToken, log) } } -func (c *ProviderConfig) removeApiToken(key 
string, apiToken string, tokens []string, cas uint32, log wrapper.Log) { - err := setApiTokens(key, removeElement(tokens, apiToken), cas) - if err != nil { - log.Errorf("Failed to remove %s from %s, err: %v", apiToken, key, err) - } +func addApiToken(key, apiToken string, log wrapper.Log) { + modifyApiToken(key, apiToken, addApiTokenOperation, log) } -func (c *ProviderConfig) resetApiTokenRequestCounter(key string, apiToken string, tokens map[string]int64, cas uint32, log wrapper.Log) { - delete(tokens, apiToken) - setApiTokenRequestCount(key, tokens, cas, log) +func removeApiToken(key, apiToken string, log wrapper.Log) { + modifyApiToken(key, apiToken, removeApiTokenOperation, log) } -func (c *ProviderConfig) addApiToken(key string, apiToken string, tokens []string, cas uint32, log wrapper.Log) { - if !containsElement(tokens, apiToken) { - tokens = append(tokens, apiToken) - err := setApiTokens(key, tokens, cas) +func modifyApiToken(key, apiToken, op string, log wrapper.Log) { + for attempt := 1; attempt <= casMaxRetries; attempt++ { + apiTokens, cas, err := getApiTokens(key) if err != nil { - log.Errorf("Failed to add %s to %s, err: %v", apiToken, key, err) + log.Errorf("Failed to get %s: %v", key, err) + continue } - } -} -func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) { - data, cas, err := proxywasm.GetSharedData(key) - if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { - return nil, 0, err - } - tokens := make(map[string]int64) - if len(data) > 0 { - pairs := strings.Split(string(data), "\n") - for _, pair := range pairs { - kv := strings.Split(pair, "\x00") - if len(kv) == 2 { - value, _ := strconv.ParseInt(kv[1], 10, 64) - tokens[kv[0]] = value - } + exists := containsElement(apiTokens, apiToken) + if op == addApiTokenOperation && exists { + log.Debugf("%s already exists in %s", apiToken, key) + return + } else if op == removeApiTokenOperation && !exists { + log.Debugf("%s does not exist in %s", apiToken, key) + return 
} - } - return tokens, cas, nil -} -func setApiTokenRequestCount(key string, tokens map[string]int64, cas uint32, log wrapper.Log) { - var pairs []string - for k, v := range tokens { - // use a special character "\x00" to separate key (token) and value (failure or success count), - // in order to retrieve tokens from byte in getApiTokenRequestCount - pair := fmt.Sprintf("%s\x00%d", k, v) - pairs = append(pairs, pair) - } - // use a special character "\n" to separate tokens - data := strings.Join(pairs, "\n") - err := proxywasm.SetSharedData(key, []byte(data), cas) - if err != nil { - log.Errorf("Failed to update %s, err: %v", key, err) + if op == addApiTokenOperation { + apiTokens = append(apiTokens, apiToken) + } else { + apiTokens = removeElement(apiTokens, apiToken) + } + + if err := setApiTokens(key, apiTokens, cas); err == nil { + log.Debugf("Successfully updated %s in %s", apiToken, key) + return + } else if !errors.Is(err, types.ErrorStatusCasMismatch) { + log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err) + return + } + + log.Errorf("CAS mismatch when setting %s, retrying...", key) } } @@ -327,11 +302,11 @@ func getApiTokens(key string) ([]string, uint32, error) { if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { return nil, 0, err } - var tokens []string + var apiTokens []string if len(data) > 0 { - tokens = strings.Split(string(data), "\n") + apiTokens = strings.Split(string(data), "\n") } - return tokens, cas, nil + return apiTokens, cas, nil } func setApiTokens(key string, tokens []string, cas uint32) error { @@ -359,6 +334,70 @@ func containsElement(slice []string, s string) bool { return false } +func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) { + data, cas, err := proxywasm.GetSharedData(key) + if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { + return nil, 0, err + } + apiTokens := make(map[string]int64) + if len(data) > 0 { + pairs := strings.Split(string(data), "\n") + for _, 
pair := range pairs { + kv := strings.Split(pair, "\x00") + if len(kv) == 2 { + value, _ := strconv.ParseInt(kv[1], 10, 64) + apiTokens[kv[0]] = value + } + } + } + return apiTokens, cas, nil +} + +func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) { + modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log) +} + +func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) { + modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log) +} + +func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) { + for attempt := 1; attempt <= casMaxRetries; attempt++ { + apiTokenRequestCount, cas, err := getApiTokenRequestCount(key) + if err != nil { + log.Errorf("Failed to get %s: %v", key, err) + continue + } + + if op == resetApiTokenRequestCountOperation { + delete(apiTokenRequestCount, apiToken) + } else { + apiTokenRequestCount[apiToken]++ + } + + data := apiTokenRequestCountToByte(apiTokenRequestCount) + + if err := proxywasm.SetSharedData(key, data, cas); err == nil { + log.Debugf("Successfully updated the count of %s in %s", apiToken, key) + return + } else if !errors.Is(err, types.ErrorStatusCasMismatch) { + log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err) + return + } + + log.Errorf("CAS mismatch when setting %s, retrying...", key) + } +} + +func apiTokenRequestCountToByte(apiTokenRequestCount map[string]int64) []byte { + var pairs []string + for k, v := range apiTokenRequestCount { + pair := fmt.Sprintf("%s\x00%d", k, v) + pairs = append(pairs, pair) + } + return []byte(strings.Join(pairs, "\n")) +} + func (c *ProviderConfig) initApiTokens() error { return setApiTokens(ctxApiTokens, c.apiTokens, 0) } From ee498483db4f668d47535d197023499f72c75e59 Mon Sep 17 00:00:00 2001 From: Se7en Date: Sat, 7 Sep 2024 13:45:51 +0800 Subject: [PATCH 05/31] wrap getApiTokenInUse funtion --- plugins/wasm-go/extensions/ai-proxy/provider/ai360.go | 2 +- 
plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/baidu.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/claude.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/deepl.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/failover.go | 4 ++++ plugins/wasm-go/extensions/ai-proxy/provider/gemini.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/groq.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/minimax.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/mistral.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/openai.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/qwen.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/spark.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/yi.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go | 2 +- 19 files changed, 22 insertions(+), 18 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index ec116125b5..4e49f16df7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -49,7 +49,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam _ = util.OverwriteRequestHost(ai360Domain) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - _ = proxywasm.ReplaceHttpRequestHeader("Authorization", ctx.GetContext(ApiTokenInUse).(string)) + _ = proxywasm.ReplaceHttpRequestHeader("Authorization", getApiTokenInUse(ctx)) // Delay the header processing to allow changing streaming mode in OnRequestBody return 
types.HeaderStopIteration, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index 2a01f079f2..c8729ec474 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -49,7 +49,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } _ = util.OverwriteRequestPath(baichuanChatCompletionPath) _ = util.OverwriteRequestHost(baichuanDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index a5b39c8d2c..2f0774d9aa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -232,7 +232,7 @@ func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel strin if !ok { suffix = baiduModel } - return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, ctx.GetContext(ApiTokenInUse).(string)) + return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, getApiTokenInUse(ctx)) } func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 3fdee70083..31e02709ea 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -108,7 +108,7 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa _ = util.OverwriteRequestPath(claudeChatCompletionPath) _ = 
util.OverwriteRequestHost(claudeDomain) - _ = proxywasm.ReplaceHttpRequestHeader("x-api-key", ctx.GetContext(ApiTokenInUse).(string)) + _ = proxywasm.ReplaceHttpRequestHeader("x-api-key", getApiTokenInUse(ctx)) if c.config.claudeVersion == "" { c.config.claudeVersion = defaultVersion diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 720349c835..798ae92466 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -49,7 +49,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A } _ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1)) _ = util.OverwriteRequestHost(cloudflareDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index dc40c621d4..32d90577f4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -79,7 +79,7 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestPath(deeplChatCompletionPath) - _ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") return types.HeaderStopIteration, nil diff --git 
a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index 99e778fb3f..e9e9c26fdc 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -49,7 +49,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } _ = util.OverwriteRequestPath(deepseekChatCompletionPath) _ = util.OverwriteRequestHost(deepseekDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 5eb9be274e..521004bb6d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -421,6 +421,10 @@ func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string { } } +func getApiTokenInUse(ctx wrapper.HttpContext) string { + return ctx.GetContext(ApiTokenInUse).(string) +} + func (c *ProviderConfig) IsFailoverEnabled() bool { return c.failover != nil && c.failover.enabled } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 65596b4d2b..4a6683f319 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -52,7 +52,7 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa return types.ActionContinue, errUnsupportedApiName } - _ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, ctx.GetContext(ApiTokenInUse).(string)) + _ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, getApiTokenInUse(ctx)) _ = 
util.OverwriteRequestHost(geminiDomain) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index b41427a196..2883e9ee92 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -47,7 +47,7 @@ func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName } _ = util.OverwriteRequestPath(groqChatCompletionPath) _ = util.OverwriteRequestHost(groqDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index ceec72d8f3..d3c912bce7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -79,7 +79,7 @@ func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestHost(minimaxDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") // Delay the header processing to allow changing streaming mode in OnRequestBody diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 7000425366..bfcf0a97da 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -44,7 +44,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName 
ApiN return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestHost(mistralDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 8dfb56f4e5..9eac6788ba 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -60,7 +60,7 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } _ = util.OverwriteRequestPath(moonshotChatCompletionPath) _ = util.OverwriteRequestHost(moonshotDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 2ff8ea69a5..f5ed9270f1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -74,7 +74,7 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa _ = util.OverwriteRequestHost(m.customDomain) } if len(m.config.apiTokens) > 0 { - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) } _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 62c07c1213..e9c1b44514 100644 
--- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -78,7 +78,7 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestHost(qwenDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) if m.config.protocol == protocolOriginal { return types.ActionContinue, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index bc6a4fd328..6f044f8e4f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -73,7 +73,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } _ = util.OverwriteRequestHost(sparkHost) _ = util.OverwriteRequestPath(sparkChatCompletionPath) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index b593cb9343..fe445fa8c8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -47,7 +47,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } _ = util.OverwriteRequestPath(stepfunChatCompletionPath) _ = util.OverwriteRequestHost(stepfunDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + 
getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index a0b6533ad6..4631793a5c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -47,7 +47,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, } _ = util.OverwriteRequestPath(yiChatCompletionPath) _ = util.OverwriteRequestHost(yiDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 6178d862c1..8cebc247e1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -46,7 +46,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } _ = util.OverwriteRequestPath(zhipuAiChatCompletionPath) _ = util.OverwriteRequestHost(zhipuAiDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string)) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } From 1e40d82f881de2540c7e39891653f6fa871368bd Mon Sep 17 00:00:00 2001 From: Se7en Date: Wed, 25 Sep 2024 20:43:23 +0800 Subject: [PATCH 06/31] only removed the apiToken when the number of consecutive request failures exceeds the threshold --- plugins/wasm-go/extensions/ai-proxy/main.go | 11 ++++++++-- .../extensions/ai-proxy/provider/failover.go | 20 +++++++++---------- 2 files changed, 19 insertions(+), 12 
deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index b8fd8c70e2..af260ab630 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -158,6 +158,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType()) status, err := proxywasm.GetHttpResponseHeader(":status") + apiTokenInUse := ctx.GetContext(provider.ApiTokenInUse).(string) if err != nil || status != "200" { if err != nil { log.Errorf("unable to load :status header from response: %v", err) @@ -167,13 +168,19 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo providerConfig := pluginConfig.GetProviderConfig() // If apiToken failover is enabled and the request is not a health check request, handle unavailable apiToken. if providerConfig.IsFailoverEnabled() && ctx.GetContext(provider.ApiTokenHealthCheck) == nil { - unavailableApiToken := ctx.GetContext(provider.ApiTokenInUse).(string) - providerConfig.HandleUnavailableApiToken(unavailableApiToken, log) + providerConfig.HandleUnavailableApiToken(apiTokenInUse, log) } return types.ActionContinue } + // Reset ctxApiTokenRequestFailureCount if the request is successful, + // the apiToken is removed only when the number of consecutive request failures exceeds the threshold. 
+ failureApiTokenRequestCount, _, err := provider.GetApiTokenRequestCount(provider.CtxApiTokenRequestFailureCount) + if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok { + provider.ResetApiTokenRequestCount(provider.CtxApiTokenRequestFailureCount, apiTokenInUse, log) + } + if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok { apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) action, err := handler.OnResponseHeaders(ctx, apiName, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 521004bb6d..91926977e3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -48,7 +48,7 @@ const ( vmIDLength = 16 // The timestamp is 8 bytes (int64) leaseLength = vmIDLength + 8 - ctxApiTokenRequestFailureCount = "apiTokenRequestFailureCount" + CtxApiTokenRequestFailureCount = "apiTokenRequestFailureCount" ctxApiTokenRequestSuccessCount = "apiTokenRequestSuccessCount" ctxApiTokens = "apiTokens" ctxUnavailableApiTokens = "unavailableApiTokens" @@ -56,7 +56,7 @@ const ( addApiTokenOperation = "addApiToken" removeApiTokenOperation = "removeApiToken" addApiTokenRequestCountOperation = "addApiTokenRequestCount" - resetApiTokenRequestCountOperation = "resetApiTokenRequestCount" + resetApiTokenRequestCountOperation = "ResetApiTokenRequestCount" ) func (f *failover) FromJson(json gjson.Result) { @@ -205,7 +205,7 @@ func generateVMID() string { // When number of request successes exceeds the threshold during health check, // add the apiToken back to the available list and remove it from the unavailable list func (c *ProviderConfig) HandleAvailableApiToken(apiToken string, log wrapper.Log) { - successApiTokenRequestCount, _, err := getApiTokenRequestCount(ctxApiTokenRequestSuccessCount) + successApiTokenRequestCount, _, err := GetApiTokenRequestCount(ctxApiTokenRequestSuccessCount) if err 
!= nil { log.Errorf("Failed to get successApiTokenRequestCount: %v", err) return @@ -216,7 +216,7 @@ func (c *ProviderConfig) HandleAvailableApiToken(apiToken string, log wrapper.Lo log.Infof("apiToken %s is available now, add it back to the apiTokens list", apiToken) removeApiToken(ctxUnavailableApiTokens, apiToken, log) addApiToken(ctxApiTokens, apiToken, log) - resetApiTokenRequestCount(ctxApiTokenRequestSuccessCount, apiToken, log) + ResetApiTokenRequestCount(ctxApiTokenRequestSuccessCount, apiToken, log) } else { log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check......", apiToken, successCount) addApiTokenRequestCount(ctxApiTokenRequestSuccessCount, apiToken, log) @@ -226,7 +226,7 @@ func (c *ProviderConfig) HandleAvailableApiToken(apiToken string, log wrapper.Lo // When number of request failures exceeds the threshold, // remove the apiToken from the available list and add it to the unavailable list func (c *ProviderConfig) HandleUnavailableApiToken(apiToken string, log wrapper.Log) { - failureApiTokenRequestCount, _, err := getApiTokenRequestCount(ctxApiTokenRequestFailureCount) + failureApiTokenRequestCount, _, err := GetApiTokenRequestCount(CtxApiTokenRequestFailureCount) if err != nil { log.Errorf("Failed to get failureApiTokenRequestCount: %v", err) return @@ -247,10 +247,10 @@ func (c *ProviderConfig) HandleUnavailableApiToken(apiToken string, log wrapper. 
log.Infof("apiToken %s is unavailable now, remove it from apiTokens list", apiToken) removeApiToken(ctxApiTokens, apiToken, log) addApiToken(ctxUnavailableApiTokens, apiToken, log) - resetApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiToken, log) + ResetApiTokenRequestCount(CtxApiTokenRequestFailureCount, apiToken, log) } else { log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount) - addApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiToken, log) + addApiTokenRequestCount(CtxApiTokenRequestFailureCount, apiToken, log) } } @@ -334,7 +334,7 @@ func containsElement(slice []string, s string) bool { return false } -func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) { +func GetApiTokenRequestCount(key string) (map[string]int64, uint32, error) { data, cas, err := proxywasm.GetSharedData(key) if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { return nil, 0, err @@ -357,13 +357,13 @@ func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) { modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log) } -func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) { +func ResetApiTokenRequestCount(key, apiToken string, log wrapper.Log) { modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log) } func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) { for attempt := 1; attempt <= casMaxRetries; attempt++ { - apiTokenRequestCount, cas, err := getApiTokenRequestCount(key) + apiTokenRequestCount, cas, err := GetApiTokenRequestCount(key) if err != nil { log.Errorf("Failed to get %s: %v", key, err) continue From 432395b1ea5d1aeff05e71df10be2983231faf0d Mon Sep 17 00:00:00 2001 From: Se7en Date: Wed, 25 Sep 2024 21:23:04 +0800 Subject: [PATCH 07/31] use uuid as vmid --- .../wasm-go/extensions/ai-proxy/provider/failover.go | 10 ++++++---- 1 file 
changed, 6 insertions(+), 4 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 91926977e3..d8ec86bf00 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/google/uuid" "math/rand" "net/http" "strconv" @@ -44,8 +45,9 @@ const ( ApiTokenInUse = "apiTokenInUse" ApiTokenHealthCheck = "apiTokenHealthCheck" vmLease = "vmLease" - // The length of vmID generated by generateVMID is fixed to 16 bytes - vmIDLength = 16 + // The length of vmID generated by generateVMID is fixed to 36 bytes + // e.g., 66043227-e5a3-48e7-8e59-9135199367ba + vmIDLength = 36 // The timestamp is 8 bytes (int64) leaseLength = vmIDLength + 8 CtxApiTokenRequestFailureCount = "apiTokenRequestFailureCount" @@ -101,7 +103,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) { } if c.failover != nil && c.failover.enabled { - wrapper.RegisteTickFunc(c.failover.healthCheckTimeout, func() { + wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() { // Only the Wasm VM that successfully acquires the lease will perform health check if tryAcquireOrRenewLease(vmID, log) { log.Debugf("Successfully acquired or renewed lease: %s", vmID) @@ -199,7 +201,7 @@ func leaseFromBytes(b []byte) *lease { } func generateVMID() string { - return fmt.Sprintf("%016x", time.Now().Nanosecond()) + return uuid.New().String() } // When number of request successes exceeds the threshold during health check, From 67551f2b8da60c7750918babf5d5a678a6c6af83 Mon Sep 17 00:00:00 2001 From: Se7en Date: Thu, 26 Sep 2024 13:10:39 +0800 Subject: [PATCH 08/31] fix byte covert --- plugins/wasm-go/extensions/ai-proxy/main.go | 3 + .../extensions/ai-proxy/provider/failover.go | 140 ++++++++---------- 2 files changed, 66 insertions(+), 77 deletions(-) diff --git 
a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index af260ab630..bcffaa39ce 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -177,6 +177,9 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo // Reset ctxApiTokenRequestFailureCount if the request is successful, // the apiToken is removed only when the number of consecutive request failures exceeds the threshold. failureApiTokenRequestCount, _, err := provider.GetApiTokenRequestCount(provider.CtxApiTokenRequestFailureCount) + if err != nil { + log.Errorf("failed to get failureApiTokenRequestCount: %v", err) + } if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok { provider.ResetApiTokenRequestCount(provider.CtxApiTokenRequestFailureCount, apiTokenInUse, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index d8ec86bf00..c7394ecf1f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -1,13 +1,12 @@ package provider import ( - "encoding/binary" + "encoding/json" "errors" "fmt" "github.com/google/uuid" "math/rand" "net/http" - "strconv" "strings" "time" @@ -32,9 +31,9 @@ type failover struct { healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` } -type lease struct { - vmID string - timestamp int64 +type Lease struct { + VMID string `json:"vmID"` + Timestamp int64 `json:"timestamp"` } var ( @@ -42,14 +41,9 @@ var ( ) const ( - ApiTokenInUse = "apiTokenInUse" - ApiTokenHealthCheck = "apiTokenHealthCheck" - vmLease = "vmLease" - // The length of vmID generated by generateVMID is fixed to 36 bytes - // e.g., 66043227-e5a3-48e7-8e59-9135199367ba - vmIDLength = 36 - // The timestamp is 8 bytes (int64) - leaseLength = vmIDLength + 8 + ApiTokenInUse = "apiTokenInUse" + 
ApiTokenHealthCheck = "apiTokenHealthCheck" + vmLease = "vmLease" CtxApiTokenRequestFailureCount = "apiTokenRequestFailureCount" ctxApiTokenRequestSuccessCount = "apiTokenRequestSuccessCount" ctxApiTokens = "apiTokens" @@ -146,23 +140,29 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) { } func tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool { + now := time.Now().Unix() + data, cas, err := proxywasm.GetSharedData(vmLease) - if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { - log.Errorf("Failed to get lease: %v", err) - return false + if err != nil { + if errors.Is(err, types.ErrorStatusNotFound) { + return setLease(vmID, now, cas, log) + } else { + log.Errorf("Failed to get lease: %v", err) + return false + } } - now := time.Now().Unix() - if data == nil { - return setLease(vmID, now, cas, log) + var lease Lease + err = json.Unmarshal(data, &lease) + if err != nil { + log.Errorf("Failed to unmarshal lease data: %v", err) + return false } - - leaseData := leaseFromBytes(data) // If vmID is itself, try to renew the lease directly - // If the lease is expired, try to acquire the lease - if leaseData.vmID == vmID || now-leaseData.timestamp > 60 { - leaseData.vmID = vmID - leaseData.timestamp = now + // If the lease is expired (60s), try to acquire the lease + if lease.VMID == vmID || now-lease.Timestamp > 60 { + lease.VMID = vmID + lease.Timestamp = now return setLease(vmID, now, cas, log) } @@ -170,34 +170,21 @@ func tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool { } func setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool { - leaseData := lease{ - vmID: vmID, - timestamp: timestamp, + lease := Lease{ + VMID: vmID, + Timestamp: timestamp, } - if err := proxywasm.SetSharedData(vmLease, leaseData.leaseToBytes(), cas); err != nil { - log.Errorf("Failed to set or renew lease: %v", err) + leaseByte, err := json.Marshal(lease) + if err != nil { + log.Errorf("Failed to marshal lease data: %v", err) return 
false } - return true -} - -func (l *lease) leaseToBytes() []byte { - b := make([]byte, leaseLength) - copy(b[:vmIDLength], l.vmID) - binary.LittleEndian.PutUint64(b[vmIDLength:], uint64(l.timestamp)) - - return b -} - -func leaseFromBytes(b []byte) *lease { - if len(b) != leaseLength { - return nil - } - return &lease{ - vmID: string(b[:vmIDLength]), - timestamp: int64(binary.LittleEndian.Uint64(b[vmIDLength:])), + if err := proxywasm.SetSharedData(vmLease, leaseByte, cas); err != nil { + log.Errorf("Failed to set or renew lease: %v", err) + return false } + return true } func generateVMID() string { @@ -301,20 +288,27 @@ func modifyApiToken(key, apiToken, op string, log wrapper.Log) { func getApiTokens(key string) ([]string, uint32, error) { data, cas, err := proxywasm.GetSharedData(key) - if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { + if err != nil { + if errors.Is(err, types.ErrorStatusNotFound) { + return []string{}, cas, nil + } return nil, 0, err } + var apiTokens []string - if len(data) > 0 { - apiTokens = strings.Split(string(data), "\n") + if err = json.Unmarshal(data, &apiTokens); err != nil { + return nil, 0, fmt.Errorf("failed to unmarshal tokens: %v", err) } + return apiTokens, cas, nil } -func setApiTokens(key string, tokens []string, cas uint32) error { - // use a special character "\n" to separate tokens, in order to retrieve tokens from byte in getApiTokens - data := strings.Join(tokens, "\n") - return proxywasm.SetSharedData(key, []byte(data), cas) +func setApiTokens(key string, apiTokens []string, cas uint32) error { + data, err := json.Marshal(apiTokens) + if err != nil { + return fmt.Errorf("failed to marshal tokens: %v", err) + } + return proxywasm.SetSharedData(key, data, cas) } func removeElement(slice []string, s string) []string { @@ -338,19 +332,17 @@ func containsElement(slice []string, s string) bool { func GetApiTokenRequestCount(key string) (map[string]int64, uint32, error) { data, cas, err := 
proxywasm.GetSharedData(key) - if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { + if err != nil { + if errors.Is(err, types.ErrorStatusNotFound) { + return make(map[string]int64), cas, nil + } return nil, 0, err } - apiTokens := make(map[string]int64) - if len(data) > 0 { - pairs := strings.Split(string(data), "\n") - for _, pair := range pairs { - kv := strings.Split(pair, "\x00") - if len(kv) == 2 { - value, _ := strconv.ParseInt(kv[1], 10, 64) - apiTokens[kv[0]] = value - } - } + + var apiTokens map[string]int64 + err = json.Unmarshal(data, &apiTokens) + if err != nil { + return nil, 0, err } return apiTokens, cas, nil } @@ -377,9 +369,12 @@ func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log apiTokenRequestCount[apiToken]++ } - data := apiTokenRequestCountToByte(apiTokenRequestCount) + apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount) + if err != nil { + log.Errorf("failed to marshal apiTokenRequestCount: %v", err) + } - if err := proxywasm.SetSharedData(key, data, cas); err == nil { + if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil { log.Debugf("Successfully updated the count of %s in %s", apiToken, key) return } else if !errors.Is(err, types.ErrorStatusCasMismatch) { @@ -391,15 +386,6 @@ func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log } } -func apiTokenRequestCountToByte(apiTokenRequestCount map[string]int64) []byte { - var pairs []string - for k, v := range apiTokenRequestCount { - pair := fmt.Sprintf("%s\x00%d", k, v) - pairs = append(pairs, pair) - } - return []byte(strings.Join(pairs, "\n")) -} - func (c *ProviderConfig) initApiTokens() error { return setApiTokens(ctxApiTokens, c.apiTokens, 0) } From 82b2284b0d6b321d6332c1e465fae1dded984330 Mon Sep 17 00:00:00 2001 From: Se7en Date: Thu, 26 Sep 2024 13:28:50 +0800 Subject: [PATCH 09/31] reset shared data during initialization --- .../extensions/ai-proxy/provider/failover.go | 21 
+++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index c7394ecf1f..b42c523aa8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -84,6 +84,9 @@ func (f *failover) Validate() error { } func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) { + // Reset shared data in case plugin configuration is updated + resetSharedData() + // TODO: 目前需要手动加一个 cluster 指向本地的地址,健康检测需要访问该地址 healthCheckClient = wrapper.NewClusterClient(wrapper.StaticIpCluster{ ServiceName: "local_cluster", @@ -151,6 +154,9 @@ func tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool { return false } } + if data == nil { + return setLease(vmID, now, cas, log) + } var lease Lease err = json.Unmarshal(data, &lease) @@ -294,6 +300,9 @@ func getApiTokens(key string) ([]string, uint32, error) { } return nil, 0, err } + if data == nil { + return []string{}, cas, nil + } var apiTokens []string if err = json.Unmarshal(data, &apiTokens); err != nil { @@ -339,6 +348,10 @@ func GetApiTokenRequestCount(key string) (map[string]int64, uint32, error) { return nil, 0, err } + if data == nil { + return make(map[string]int64), cas, nil + } + var apiTokens map[string]int64 err = json.Unmarshal(data, &apiTokens) if err != nil { @@ -416,3 +429,11 @@ func getApiTokenInUse(ctx wrapper.HttpContext) string { func (c *ProviderConfig) IsFailoverEnabled() bool { return c.failover != nil && c.failover.enabled } + +func resetSharedData() { + _ = proxywasm.SetSharedData(vmLease, nil, 0) + _ = proxywasm.SetSharedData(ctxApiTokens, nil, 0) + _ = proxywasm.SetSharedData(ctxUnavailableApiTokens, nil, 0) + _ = proxywasm.SetSharedData(ctxApiTokenRequestSuccessCount, nil, 0) + _ = proxywasm.SetSharedData(CtxApiTokenRequestFailureCount, nil, 0) +} From 8a818ed5b3e7281ba63abb0ab67c6b66a5e556c2 Mon Sep 17 00:00:00 
2001 From: Se7en Date: Thu, 26 Sep 2024 13:39:18 +0800 Subject: [PATCH 10/31] failover support new model --- plugins/wasm-go/extensions/ai-proxy/provider/cohere.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/doubao.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 7ffe1708af..aa6b422593 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -59,7 +59,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } _ = util.OverwriteRequestHost(cohereDomain) _ = util.OverwriteRequestPath(chatCompletionPath) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 0ca349a773..85ff3f4a8f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -42,7 +42,7 @@ func (m *doubaoProvider) GetProviderType() string { func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { _ = util.OverwriteRequestHost(doubaoDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") if m.config.protocol == protocolOriginal { ctx.DontReadRequestBody() From 0554c85fa6468b8d0c5290c634ff725175168618 Mon Sep 17 00:00:00 2001 From: Se7en Date: Thu, 26 Sep 2024 13:48:39 +0800 Subject: [PATCH 11/31] fix --- 
plugins/wasm-go/extensions/ai-proxy/main.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index dc568746ae..f404be103b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -47,6 +47,10 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log if err := pluginConfig.Complete(); err != nil { return err } + + providerConfig := pluginConfig.GetProviderConfig() + providerConfig.SetApiTokensFailover(log) + return nil } @@ -63,9 +67,6 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug return err } - providerConfig := pluginConfig.GetProviderConfig() - providerConfig.SetApiTokensFailover(log) - return nil } From e3401d5f02dbc25ba7ffc9d772a802098bb76a68 Mon Sep 17 00:00:00 2001 From: Se7en Date: Sat, 28 Sep 2024 20:37:05 +0800 Subject: [PATCH 12/31] move SetApiTokensFailover to complete function --- plugins/wasm-go/extensions/ai-proxy/config/config.go | 11 +++++++++-- plugins/wasm-go/extensions/ai-proxy/main.go | 7 ++----- .../wasm-go/extensions/ai-proxy/provider/failover.go | 5 +++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go index 9019a92c9f..d604630e31 100644 --- a/plugins/wasm-go/extensions/ai-proxy/config/config.go +++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go @@ -2,6 +2,7 @@ package config import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) @@ -74,12 +75,18 @@ func (c *PluginConfig) Validate() error { return nil } -func (c *PluginConfig) Complete() error { +func (c *PluginConfig) Complete(log wrapper.Log) error { if c.activeProviderConfig == nil { c.activeProvider = nil return nil } - 
var err error + + providerConfig := c.GetProviderConfig() + err := providerConfig.SetApiTokensFailover(log) + if err != nil { + return err + } + c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig) return err } diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index f404be103b..9bf230fb5c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -44,13 +44,10 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log if err := pluginConfig.Validate(); err != nil { return err } - if err := pluginConfig.Complete(); err != nil { + if err := pluginConfig.Complete(log); err != nil { return err } - providerConfig := pluginConfig.GetProviderConfig() - providerConfig.SetApiTokensFailover(log) - return nil } @@ -63,7 +60,7 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug if err := pluginConfig.Validate(); err != nil { return err } - if err := pluginConfig.Complete(); err != nil { + if err := pluginConfig.Complete(log); err != nil { return err } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index b42c523aa8..00d13f439e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -83,7 +83,7 @@ func (f *failover) Validate() error { return nil } -func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) { +func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { // Reset shared data in case plugin configuration is updated resetSharedData() @@ -96,7 +96,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) { vmID := generateVMID() err := c.initApiTokens() if err != nil { - log.Errorf("Failed to init apiTokens: %v", err) + return fmt.Errorf("failed to init apiTokens: %v", err) } if c.failover != nil && 
c.failover.enabled { @@ -140,6 +140,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) { } }) } + return nil } func tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool { From 0f79913d54e64cecdce056f8097cd757ff825bb4 Mon Sep 17 00:00:00 2001 From: Se7en Date: Sat, 28 Sep 2024 21:26:26 +0800 Subject: [PATCH 13/31] wrap failover logic into ProviderConfig --- plugins/wasm-go/extensions/ai-proxy/main.go | 35 ++-------- .../extensions/ai-proxy/provider/failover.go | 68 +++++++++++++------ 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 9bf230fb5c..102cf015c6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -92,20 +92,8 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok { // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. ctx.DisableReroute() - - providerConfig := pluginConfig.GetProviderConfig() - apiTokenInUse := providerConfig.GetRandomToken() - if providerConfig.IsFailoverEnabled() { - // Use the health check token if it is a health check request. - if apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader("ApiToken-Health-Check"); apiTokenHealthCheck != "" { - apiTokenInUse = apiTokenHealthCheck - } else { - // if enable apiToken failover, only use available apiToken - apiTokenInUse = providerConfig.GetGlobalRandomToken(log) - } - } - log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiTokenInUse) - ctx.SetContext(provider.ApiTokenInUse, apiTokenInUse) + // Set the apiToken for the current request. 
+ providerConfig.SetApiTokenInUse(ctx, log) hasRequestBody := wrapper.HasRequestBody() action, err := handler.OnRequestHeaders(ctx, apiName, log) @@ -173,32 +161,23 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType()) + providerConfig := pluginConfig.GetProviderConfig() + apiTokenInUse := providerConfig.GetApiTokenInUse(ctx) + status, err := proxywasm.GetHttpResponseHeader(":status") - apiTokenInUse := ctx.GetContext(provider.ApiTokenInUse).(string) if err != nil || status != "200" { if err != nil { log.Errorf("unable to load :status header from response: %v", err) } ctx.DontReadResponseBody() - - providerConfig := pluginConfig.GetProviderConfig() - // If apiToken failover is enabled and the request is not a health check request, handle unavailable apiToken. - if providerConfig.IsFailoverEnabled() && ctx.GetContext(provider.ApiTokenHealthCheck) == nil { - providerConfig.HandleUnavailableApiToken(apiTokenInUse, log) - } + providerConfig.OnRequestFailed(ctx, apiTokenInUse, log) return types.ActionContinue } // Reset ctxApiTokenRequestFailureCount if the request is successful, // the apiToken is removed only when the number of consecutive request failures exceeds the threshold. 
- failureApiTokenRequestCount, _, err := provider.GetApiTokenRequestCount(provider.CtxApiTokenRequestFailureCount) - if err != nil { - log.Errorf("failed to get failureApiTokenRequestCount: %v", err) - } - if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok { - provider.ResetApiTokenRequestCount(provider.CtxApiTokenRequestFailureCount, apiTokenInUse, log) - } + providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log) if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok { apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 00d13f439e..d99146ccfb 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -41,10 +41,10 @@ var ( ) const ( - ApiTokenInUse = "apiTokenInUse" - ApiTokenHealthCheck = "apiTokenHealthCheck" + apiTokenInUse = "apiTokenInUse" + apiTokenHealthCheck = "apiTokenHealthCheck" vmLease = "vmLease" - CtxApiTokenRequestFailureCount = "apiTokenRequestFailureCount" + ctxApiTokenRequestFailureCount = "apiTokenRequestFailureCount" ctxApiTokenRequestSuccessCount = "apiTokenRequestSuccessCount" ctxApiTokens = "apiTokens" ctxUnavailableApiTokens = "unavailableApiTokens" @@ -129,7 +129,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { }`, c.failover.healthCheckModel)) err := healthCheckClient.Post(path, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode == 200 { - c.HandleAvailableApiToken(apiToken, log) + c.handleAvailableApiToken(apiToken, log) } }, uint32(c.failover.healthCheckTimeout)) if err != nil { @@ -200,8 +200,8 @@ func generateVMID() string { // When number of request successes exceeds the threshold during health check, // add the apiToken back to the available list and remove it from the unavailable list -func (c 
*ProviderConfig) HandleAvailableApiToken(apiToken string, log wrapper.Log) { - successApiTokenRequestCount, _, err := GetApiTokenRequestCount(ctxApiTokenRequestSuccessCount) +func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Log) { + successApiTokenRequestCount, _, err := getApiTokenRequestCount(ctxApiTokenRequestSuccessCount) if err != nil { log.Errorf("Failed to get successApiTokenRequestCount: %v", err) return @@ -212,7 +212,7 @@ func (c *ProviderConfig) HandleAvailableApiToken(apiToken string, log wrapper.Lo log.Infof("apiToken %s is available now, add it back to the apiTokens list", apiToken) removeApiToken(ctxUnavailableApiTokens, apiToken, log) addApiToken(ctxApiTokens, apiToken, log) - ResetApiTokenRequestCount(ctxApiTokenRequestSuccessCount, apiToken, log) + resetApiTokenRequestCount(ctxApiTokenRequestSuccessCount, apiToken, log) } else { log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check......", apiToken, successCount) addApiTokenRequestCount(ctxApiTokenRequestSuccessCount, apiToken, log) @@ -221,8 +221,8 @@ func (c *ProviderConfig) HandleAvailableApiToken(apiToken string, log wrapper.Lo // When number of request failures exceeds the threshold, // remove the apiToken from the available list and add it to the unavailable list -func (c *ProviderConfig) HandleUnavailableApiToken(apiToken string, log wrapper.Log) { - failureApiTokenRequestCount, _, err := GetApiTokenRequestCount(CtxApiTokenRequestFailureCount) +func (c *ProviderConfig) handleUnavailableApiToken(apiToken string, log wrapper.Log) { + failureApiTokenRequestCount, _, err := getApiTokenRequestCount(ctxApiTokenRequestFailureCount) if err != nil { log.Errorf("Failed to get failureApiTokenRequestCount: %v", err) return @@ -243,10 +243,10 @@ func (c *ProviderConfig) HandleUnavailableApiToken(apiToken string, log wrapper. 
log.Infof("apiToken %s is unavailable now, remove it from apiTokens list", apiToken) removeApiToken(ctxApiTokens, apiToken, log) addApiToken(ctxUnavailableApiTokens, apiToken, log) - ResetApiTokenRequestCount(CtxApiTokenRequestFailureCount, apiToken, log) + resetApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiToken, log) } else { log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount) - addApiTokenRequestCount(CtxApiTokenRequestFailureCount, apiToken, log) + addApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiToken, log) } } @@ -340,7 +340,7 @@ func containsElement(slice []string, s string) bool { return false } -func GetApiTokenRequestCount(key string) (map[string]int64, uint32, error) { +func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) { data, cas, err := proxywasm.GetSharedData(key) if err != nil { if errors.Is(err, types.ErrorStatusNotFound) { @@ -365,13 +365,17 @@ func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) { modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log) } -func ResetApiTokenRequestCount(key, apiToken string, log wrapper.Log) { +func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) { modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log) } +func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) { + resetApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiTokenInUse, log) +} + func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) { for attempt := 1; attempt <= casMaxRetries; attempt++ { - apiTokenRequestCount, cas, err := GetApiTokenRequestCount(key) + apiTokenRequestCount, cas, err := getApiTokenRequestCount(key) if err != nil { log.Errorf("Failed to get %s: %v", key, err) continue @@ -423,10 +427,6 @@ func (c *ProviderConfig) GetGlobalRandomToken(log 
wrapper.Log) string { } } -func getApiTokenInUse(ctx wrapper.HttpContext) string { - return ctx.GetContext(ApiTokenInUse).(string) -} - func (c *ProviderConfig) IsFailoverEnabled() bool { return c.failover != nil && c.failover.enabled } @@ -436,5 +436,35 @@ func resetSharedData() { _ = proxywasm.SetSharedData(ctxApiTokens, nil, 0) _ = proxywasm.SetSharedData(ctxUnavailableApiTokens, nil, 0) _ = proxywasm.SetSharedData(ctxApiTokenRequestSuccessCount, nil, 0) - _ = proxywasm.SetSharedData(CtxApiTokenRequestFailureCount, nil, 0) + _ = proxywasm.SetSharedData(ctxApiTokenRequestFailureCount, nil, 0) +} + +func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) { + // If apiToken failover is enabled and the request is not a health check request, handle unavailable apiToken. + if c.IsFailoverEnabled() && ctx.GetContext(apiTokenHealthCheck) == nil { + c.handleUnavailableApiToken(apiTokenInUse, log) + } +} + +func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { + return getApiTokenInUse(ctx) +} + +func getApiTokenInUse(ctx wrapper.HttpContext) string { + return ctx.GetContext(apiTokenInUse).(string) +} + +func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { + apiTokenInUse := c.GetRandomToken() + if c.IsFailoverEnabled() { + // Use the health check token if it is a health check request. 
+ if apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader("ApiToken-Health-Check"); apiTokenHealthCheck != "" { + apiTokenInUse = apiTokenHealthCheck + } else { + // if enable apiToken failover, only use available apiToken + apiTokenInUse = c.GetGlobalRandomToken(log) + } + } + log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiTokenInUse) + ctx.SetContext(apiTokenInUse, apiTokenInUse) } From bda87f14f74b69d60e20025828be435c7cf336dd Mon Sep 17 00:00:00 2001 From: Se7en Date: Sat, 28 Sep 2024 22:07:48 +0800 Subject: [PATCH 14/31] fix --- .../extensions/ai-proxy/provider/failover.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index d99146ccfb..37e8bb4436 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -370,7 +370,13 @@ func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) { } func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) { - resetApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiTokenInUse, log) + failureApiTokenRequestCount, _, err := getApiTokenRequestCount(ctxApiTokenRequestFailureCount) + if err != nil { + log.Errorf("failed to get failureApiTokenRequestCount: %v", err) + } + if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok { + resetApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiTokenInUse, log) + } } func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) { @@ -455,16 +461,16 @@ func getApiTokenInUse(ctx wrapper.HttpContext) string { } func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { - apiTokenInUse := c.GetRandomToken() + apiToken := c.GetRandomToken() if c.IsFailoverEnabled() { // Use the health check token if it is a health check request. 
if apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader("ApiToken-Health-Check"); apiTokenHealthCheck != "" { - apiTokenInUse = apiTokenHealthCheck + apiToken = apiTokenHealthCheck } else { // if enable apiToken failover, only use available apiToken - apiTokenInUse = c.GetGlobalRandomToken(log) + apiToken = c.GetGlobalRandomToken(log) } } - log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiTokenInUse) - ctx.SetContext(apiTokenInUse, apiTokenInUse) + log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken) + ctx.SetContext(apiTokenInUse, apiToken) } From 263c38cafb000589135ac11c64813dea513c0772 Mon Sep 17 00:00:00 2001 From: Se7en Date: Sat, 5 Oct 2024 17:10:42 +0800 Subject: [PATCH 15/31] config envoy local cluster and isolate apiToken ctx between different providers --- helm/core/templates/_pod.tpl | 6 - helm/core/templates/configmap.yaml | 31 ++- .../extensions/ai-proxy/config/config.go | 8 +- plugins/wasm-go/extensions/ai-proxy/main.go | 2 + .../extensions/ai-proxy/provider/ai360.go | 2 +- .../extensions/ai-proxy/provider/azure.go | 2 +- .../extensions/ai-proxy/provider/baichuan.go | 2 +- .../extensions/ai-proxy/provider/baidu.go | 2 +- .../extensions/ai-proxy/provider/claude.go | 2 +- .../ai-proxy/provider/cloudflare.go | 2 +- .../extensions/ai-proxy/provider/cohere.go | 2 +- .../extensions/ai-proxy/provider/deepl.go | 2 +- .../extensions/ai-proxy/provider/deepseek.go | 2 +- .../extensions/ai-proxy/provider/doubao.go | 2 +- .../extensions/ai-proxy/provider/failover.go | 199 +++++++++++------- .../extensions/ai-proxy/provider/gemini.go | 2 +- .../extensions/ai-proxy/provider/groq.go | 2 +- .../extensions/ai-proxy/provider/minimax.go | 2 +- .../extensions/ai-proxy/provider/mistral.go | 2 +- .../extensions/ai-proxy/provider/moonshot.go | 2 +- .../extensions/ai-proxy/provider/openai.go | 2 +- .../extensions/ai-proxy/provider/qwen.go | 2 +- .../extensions/ai-proxy/provider/spark.go | 2 +- 
.../extensions/ai-proxy/provider/stepfun.go | 2 +- .../extensions/ai-proxy/provider/yi.go | 2 +- .../extensions/ai-proxy/provider/zhipuai.go | 2 +- .../wasm-go/pkg/wrapper/cluster_wrapper.go | 6 +- 27 files changed, 184 insertions(+), 110 deletions(-) diff --git a/helm/core/templates/_pod.tpl b/helm/core/templates/_pod.tpl index 432f9d3d4e..ce17b31c85 100644 --- a/helm/core/templates/_pod.tpl +++ b/helm/core/templates/_pod.tpl @@ -123,10 +123,8 @@ template: - name: LITE_METRICS value: "on" {{- end }} - {{- if include "skywalking.enabled" . }} - name: ISTIO_BOOTSTRAP_OVERRIDE value: /etc/istio/custom-bootstrap/custom_bootstrap.json - {{- end }} {{- with .Values.gateway.networkGateway }} - name: ISTIO_META_REQUESTED_NETWORK_VIEW value: "{{.}}" @@ -188,10 +186,8 @@ template: mountPath: /etc/istio/pod - name: proxy-socket mountPath: /etc/istio/proxy - {{- if include "skywalking.enabled" . }} - mountPath: /etc/istio/custom-bootstrap name: custom-bootstrap-volume - {{- end }} {{- if .Values.global.volumeWasmPlugins }} - mountPath: /opt/plugins name: local-wasmplugins-volume @@ -276,12 +272,10 @@ template: - name: config configMap: name: higress-config - {{- if include "skywalking.enabled" . }} - configMap: defaultMode: 420 name: higress-custom-bootstrap name: custom-bootstrap-volume - {{- end }} - name: istio-data emptyDir: {} - name: proxy-socket diff --git a/helm/core/templates/configmap.yaml b/helm/core/templates/configmap.yaml index 9a07c3392c..456f0c521b 100644 --- a/helm/core/templates/configmap.yaml +++ b/helm/core/templates/configmap.yaml @@ -136,7 +136,6 @@ data: {{- include "mesh" . }} {{- end }} --- -{{- if include "skywalking.enabled" . }} apiVersion: v1 kind: ConfigMap metadata: @@ -147,6 +146,7 @@ metadata: data: custom_bootstrap.json: |- { + {{- if include "skywalking.enabled" . }} "stats_sinks": [ { "name": "envoy.metrics_service", @@ -161,8 +161,10 @@ data: } } ], + {{- end }} "static_resources": { "clusters": [ + {{- if include "skywalking.enabled" . 
}} { "name": "service_skywalking", "type": "LOGICAL_DNS", @@ -190,9 +192,32 @@ data: } ] } + }, + {{- end }} + { + "name": "higress-gateway-local", + "type": "STATIC", + "connect_timeout": "5s", + "load_assignment": { + "cluster_name": "higress-gateway-local", + "endpoints": [ + { + "lb_endpoints": [ + { + "endpoint": { + "address": { + "socket_address": { + "address": "127.0.0.1", + "port_value": {{ .Values.gateway.httpPort }} + } + } + } + } + ] + } + ] + } } ] } } ---- -{{- end }} diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go index d604630e31..ab9962f525 100644 --- a/plugins/wasm-go/extensions/ai-proxy/config/config.go +++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go @@ -80,14 +80,12 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { c.activeProvider = nil return nil } + var err error + c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig) providerConfig := c.GetProviderConfig() - err := providerConfig.SetApiTokensFailover(log) - if err != nil { - return err - } + err = providerConfig.SetApiTokensFailover(log) - c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig) return err } diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 102cf015c6..5ac79ef40e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -94,6 +94,8 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf ctx.DisableReroute() // Set the apiToken for the current request. 
providerConfig.SetApiTokenInUse(ctx, log) + // Set the request host and path to shared data in case they are needed in apiToken health check + providerConfig.SetRequestHostAndPath(log) hasRequestBody := wrapper.HasRequestBody() action, err := handler.OnRequestHeaders(ctx, apiName, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 4e49f16df7..ae77649e51 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -49,7 +49,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam _ = util.OverwriteRequestHost(ai360Domain) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - _ = proxywasm.ReplaceHttpRequestHeader("Authorization", getApiTokenInUse(ctx)) + _ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetApiTokenInUse(ctx)) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 2dcba2f8ff..959bd94061 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -57,7 +57,7 @@ func (m *azureProvider) GetProviderType() string { func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { _ = util.OverwriteRequestPath(m.serviceUrl.RequestURI()) _ = util.OverwriteRequestHost(m.serviceUrl.Host) - _ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.apiTokens[0]) + _ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.GetApiTokenInUse(ctx)) if apiName == ApiNameChatCompletion { _ = proxywasm.RemoveHttpRequestHeader("Content-Length") } else { diff --git 
a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index c8729ec474..af35f920d5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -49,7 +49,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } _ = util.OverwriteRequestPath(baichuanChatCompletionPath) _ = util.OverwriteRequestHost(baichuanDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index 2f0774d9aa..d0197e4aa5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -232,7 +232,7 @@ func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel strin if !ok { suffix = baiduModel } - return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, getApiTokenInUse(ctx)) + return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetApiTokenInUse(ctx)) } func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 31e02709ea..90502a6605 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -108,7 +108,7 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa _ = util.OverwriteRequestPath(claudeChatCompletionPath) _ = util.OverwriteRequestHost(claudeDomain) - _ = 
proxywasm.ReplaceHttpRequestHeader("x-api-key", getApiTokenInUse(ctx)) + _ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetApiTokenInUse(ctx)) if c.config.claudeVersion == "" { c.config.claudeVersion = defaultVersion diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 798ae92466..b52f942b65 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -49,7 +49,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A } _ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1)) _ = util.OverwriteRequestHost(cloudflareDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + c.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index aa6b422593..f233955afa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -59,7 +59,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } _ = util.OverwriteRequestHost(cohereDomain) _ = util.OverwriteRequestPath(chatCompletionPath) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 32d90577f4..6ff536dd70 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -79,7 +79,7 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestPath(deeplChatCompletionPath) - _ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") return types.HeaderStopIteration, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index e9e9c26fdc..ceb1eb1eba 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -49,7 +49,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } _ = util.OverwriteRequestPath(deepseekChatCompletionPath) _ = util.OverwriteRequestHost(deepseekDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 85ff3f4a8f..5e4d847be3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -42,7 +42,7 @@ func (m *doubaoProvider) GetProviderType() string { func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { _ = util.OverwriteRequestHost(doubaoDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + 
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") if m.config.protocol == protocolOriginal { ctx.DontReadRequestBody() diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 37e8bb4436..e16c577b65 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -29,6 +29,16 @@ type failover struct { healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"` // @Title zh-CN 健康检测使用的模型 healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` + + ctxApiTokenInUse string + ctxApiTokenHealthCheck string + ctxHealthCheckHeader string + ctxApiTokenRequestFailureCount string + ctxApiTokenRequestSuccessCount string + ctxApiTokens string + ctxUnavailableApiTokens string + ctxRequestHostAndPath string + ctxVmLease string } type Lease struct { @@ -36,23 +46,22 @@ type Lease struct { Timestamp int64 `json:"timestamp"` } -var ( - healthCheckClient wrapper.HttpClient -) +type HostPath struct { + Host string `json:"host"` + Path string `json:"path"` +} const ( - apiTokenInUse = "apiTokenInUse" - apiTokenHealthCheck = "apiTokenHealthCheck" - vmLease = "vmLease" - ctxApiTokenRequestFailureCount = "apiTokenRequestFailureCount" - ctxApiTokenRequestSuccessCount = "apiTokenRequestSuccessCount" - ctxApiTokens = "apiTokens" - ctxUnavailableApiTokens = "unavailableApiTokens" casMaxRetries = 10 addApiTokenOperation = "addApiToken" removeApiTokenOperation = "removeApiToken" addApiTokenRequestCountOperation = "addApiTokenRequestCount" - resetApiTokenRequestCountOperation = "ResetApiTokenRequestCount" + resetApiTokenRequestCountOperation = "resetApiTokenRequestCount" + higressGatewayLocalCluster = "higress-gateway-local" +) + +var ( + healthCheckClient wrapper.HttpClient ) func (f 
*failover) FromJson(json gjson.Result) { @@ -83,15 +92,25 @@ func (f *failover) Validate() error { return nil } +func (c *ProviderConfig) initVariable() { + // Set provider name as prefix to differentiate shared data + provider := c.GetType() + c.failover.ctxApiTokenInUse = provider + "-apiTokenInUse" + c.failover.ctxApiTokenHealthCheck = provider + "-apiTokenHealthCheck" + c.failover.ctxHealthCheckHeader = provider + "-apiToken-health-check" + c.failover.ctxApiTokenRequestFailureCount = provider + "-apiTokenRequestFailureCount" + c.failover.ctxApiTokenRequestSuccessCount = provider + "-apiTokenRequestSuccessCount" + c.failover.ctxApiTokens = provider + "-apiTokens" + c.failover.ctxUnavailableApiTokens = provider + "-unavailableApiTokens" + c.failover.ctxRequestHostAndPath = provider + "-requestHostAndPath" + c.failover.ctxVmLease = provider + "-vmLease" +} + func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { // Reset shared data in case plugin configuration is updated - resetSharedData() - - // TODO: 目前需要手动加一个 cluster 指向本地的地址,健康检测需要访问该地址 - healthCheckClient = wrapper.NewClusterClient(wrapper.StaticIpCluster{ - ServiceName: "local_cluster", - Port: 10000, - }) + log.Debugf("Wasm plugin configuration is updated, reset shared data") + c.resetSharedData() + c.initVariable() vmID := generateVMID() err := c.initApiTokens() @@ -102,9 +121,9 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { if c.failover != nil && c.failover.enabled { wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() { // Only the Wasm VM that successfully acquires the lease will perform health check - if tryAcquireOrRenewLease(vmID, log) { - log.Debugf("Successfully acquired or renewed lease: %s", vmID) - unavailableTokens, _, err := getApiTokens(ctxUnavailableApiTokens) + if c.tryAcquireOrRenewLease(vmID, log) { + log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType()) + unavailableTokens, _, err := 
getApiTokens(c.failover.ctxUnavailableApiTokens) if err != nil { log.Errorf("Failed to get unavailable tokens: %v", err) return @@ -112,22 +131,13 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { if len(unavailableTokens) > 0 { for _, apiToken := range unavailableTokens { log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", ")) - - path := "/v1/chat/completions" - headers := [][2]string{ - {"Content-Type", "application/json"}, - {"ApiToken-Health-Check", apiToken}, - } - body := []byte(fmt.Sprintf(`{ - "model": "%s", - "messages": [ - { - "role": "user", - "content": "who are you?" - } - ] - }`, c.failover.healthCheckModel)) - err := healthCheckClient.Post(path, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + hostPath, headers, body := c.generateRequestHeadersAndBody(apiToken, log) + fmt.Println("host", hostPath.Host, "path", hostPath.Path) + healthCheckClient = wrapper.NewClusterClient(wrapper.RouteCluster{ + Host: hostPath.Host, + Cluster: higressGatewayLocalCluster, + }) + err = healthCheckClient.Post(hostPath.Path, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode == 200 { c.handleAvailableApiToken(apiToken, log) } @@ -143,20 +153,48 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { return nil } -func tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool { +func (c *ProviderConfig) generateRequestHeadersAndBody(apiToken string, log wrapper.Log) (HostPath, [][2]string, []byte) { + data, _, err := proxywasm.GetSharedData(c.failover.ctxRequestHostAndPath) + if err != nil { + log.Errorf("Failed to get request host and path: %v", err) + } + var hostPath HostPath + err = json.Unmarshal(data, &hostPath) + if err != nil { + log.Errorf("Failed to unmarshal request host and path: %v", err) + } + + headers := [][2]string{ + {"content-type", "application/json"}, + 
{c.failover.ctxHealthCheckHeader, apiToken}, + {":authority", hostPath.Host}, + } + body := []byte(fmt.Sprintf(`{ + "model": "%s", + "messages": [ + { + "role": "user", + "content": "who are you?" + } + ] + }`, c.failover.healthCheckModel)) + return hostPath, headers, body +} + +func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool { now := time.Now().Unix() - data, cas, err := proxywasm.GetSharedData(vmLease) + data, cas, err := proxywasm.GetSharedData(c.failover.ctxVmLease) if err != nil { if errors.Is(err, types.ErrorStatusNotFound) { - return setLease(vmID, now, cas, log) + return c.setLease(vmID, now, cas, log) } else { log.Errorf("Failed to get lease: %v", err) return false } } if data == nil { - return setLease(vmID, now, cas, log) + return c.setLease(vmID, now, cas, log) } var lease Lease @@ -170,13 +208,13 @@ func tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool { if lease.VMID == vmID || now-lease.Timestamp > 60 { lease.VMID = vmID lease.Timestamp = now - return setLease(vmID, now, cas, log) + return c.setLease(vmID, now, cas, log) } return false } -func setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool { +func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool { lease := Lease{ VMID: vmID, Timestamp: timestamp, @@ -187,7 +225,7 @@ func setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool { return false } - if err := proxywasm.SetSharedData(vmLease, leaseByte, cas); err != nil { + if err := proxywasm.SetSharedData(c.failover.ctxVmLease, leaseByte, cas); err != nil { log.Errorf("Failed to set or renew lease: %v", err) return false } @@ -201,7 +239,7 @@ func generateVMID() string { // When number of request successes exceeds the threshold during health check, // add the apiToken back to the available list and remove it from the unavailable list func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Log) { - 
successApiTokenRequestCount, _, err := getApiTokenRequestCount(ctxApiTokenRequestSuccessCount) + successApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount) if err != nil { log.Errorf("Failed to get successApiTokenRequestCount: %v", err) return @@ -210,25 +248,25 @@ func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Lo successCount := successApiTokenRequestCount[apiToken] + 1 if successCount >= c.failover.successThreshold { log.Infof("apiToken %s is available now, add it back to the apiTokens list", apiToken) - removeApiToken(ctxUnavailableApiTokens, apiToken, log) - addApiToken(ctxApiTokens, apiToken, log) - resetApiTokenRequestCount(ctxApiTokenRequestSuccessCount, apiToken, log) + removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log) + addApiToken(c.failover.ctxApiTokens, apiToken, log) + resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log) } else { log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check......", apiToken, successCount) - addApiTokenRequestCount(ctxApiTokenRequestSuccessCount, apiToken, log) + addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log) } } // When number of request failures exceeds the threshold, // remove the apiToken from the available list and add it to the unavailable list func (c *ProviderConfig) handleUnavailableApiToken(apiToken string, log wrapper.Log) { - failureApiTokenRequestCount, _, err := getApiTokenRequestCount(ctxApiTokenRequestFailureCount) + failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount) if err != nil { log.Errorf("Failed to get failureApiTokenRequestCount: %v", err) return } - availableTokens, _, err := getApiTokens(ctxApiTokens) + availableTokens, _, err := getApiTokens(c.failover.ctxApiTokens) if err != nil { log.Errorf("Failed to get available apiToken: %v", err) 
return @@ -241,12 +279,12 @@ func (c *ProviderConfig) handleUnavailableApiToken(apiToken string, log wrapper. failureCount := failureApiTokenRequestCount[apiToken] + 1 if failureCount >= c.failover.failureThreshold { log.Infof("apiToken %s is unavailable now, remove it from apiTokens list", apiToken) - removeApiToken(ctxApiTokens, apiToken, log) - addApiToken(ctxUnavailableApiTokens, apiToken, log) - resetApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiToken, log) + removeApiToken(c.failover.ctxApiTokens, apiToken, log) + addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log) + resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log) } else { log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount) - addApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiToken, log) + addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log) } } @@ -370,12 +408,12 @@ func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) { } func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) { - failureApiTokenRequestCount, _, err := getApiTokenRequestCount(ctxApiTokenRequestFailureCount) + failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount) if err != nil { log.Errorf("failed to get failureApiTokenRequestCount: %v", err) } if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok { - resetApiTokenRequestCount(ctxApiTokenRequestFailureCount, apiTokenInUse, log) + resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log) } } @@ -411,12 +449,13 @@ func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log } func (c *ProviderConfig) initApiTokens() error { - return setApiTokens(ctxApiTokens, c.apiTokens, 0) + return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0) } func (c 
*ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string { - apiTokens, _, err := getApiTokens(ctxApiTokens) - unavailableApiTokens, _, err := getApiTokens(ctxUnavailableApiTokens) + fmt.Println(c.apiTokens) + apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens) + unavailableApiTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens) log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens) if err != nil { @@ -437,34 +476,30 @@ func (c *ProviderConfig) IsFailoverEnabled() bool { return c.failover != nil && c.failover.enabled } -func resetSharedData() { - _ = proxywasm.SetSharedData(vmLease, nil, 0) - _ = proxywasm.SetSharedData(ctxApiTokens, nil, 0) - _ = proxywasm.SetSharedData(ctxUnavailableApiTokens, nil, 0) - _ = proxywasm.SetSharedData(ctxApiTokenRequestSuccessCount, nil, 0) - _ = proxywasm.SetSharedData(ctxApiTokenRequestFailureCount, nil, 0) +func (c *ProviderConfig) resetSharedData() { + _ = proxywasm.SetSharedData(c.failover.ctxVmLease, nil, 0) + _ = proxywasm.SetSharedData(c.failover.ctxApiTokens, nil, 0) + _ = proxywasm.SetSharedData(c.failover.ctxUnavailableApiTokens, nil, 0) + _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestSuccessCount, nil, 0) + _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0) } func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) { // If apiToken failover is enabled and the request is not a health check request, handle unavailable apiToken. 
- if c.IsFailoverEnabled() && ctx.GetContext(apiTokenHealthCheck) == nil { + if c.IsFailoverEnabled() && ctx.GetContext(c.failover.ctxApiTokenHealthCheck) == nil { c.handleUnavailableApiToken(apiTokenInUse, log) } } func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { - return getApiTokenInUse(ctx) -} - -func getApiTokenInUse(ctx wrapper.HttpContext) string { - return ctx.GetContext(apiTokenInUse).(string) + return ctx.GetContext(c.failover.ctxApiTokenInUse).(string) } func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { apiToken := c.GetRandomToken() if c.IsFailoverEnabled() { // Use the health check token if it is a health check request. - if apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader("ApiToken-Health-Check"); apiTokenHealthCheck != "" { + if apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader(c.failover.ctxHealthCheckHeader); apiTokenHealthCheck != "" { apiToken = apiTokenHealthCheck } else { // if enable apiToken failover, only use available apiToken @@ -472,5 +507,21 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L } } log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken) - ctx.SetContext(apiTokenInUse, apiToken) + ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) +} + +func (c *ProviderConfig) SetRequestHostAndPath(log wrapper.Log) { + hostPath := HostPath{ + Host: wrapper.GetRequestHost(), + Path: wrapper.GetRequestPath(), + } + hostPathByte, err := json.Marshal(hostPath) + if err != nil { + log.Errorf("Failed to marshal request host and path: %v", err) + + } + err = proxywasm.SetSharedData(c.failover.ctxRequestHostAndPath, hostPathByte, 0) + if err != nil { + log.Errorf("Failed to set request host and path: %v", err) + } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 4a6683f319..b9c157d864 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -52,7 +52,7 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa return types.ActionContinue, errUnsupportedApiName } - _ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, getApiTokenInUse(ctx)) + _ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) _ = util.OverwriteRequestHost(geminiDomain) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index 2883e9ee92..e97c92c330 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -47,7 +47,7 @@ func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName } _ = util.OverwriteRequestPath(groqChatCompletionPath) _ = util.OverwriteRequestHost(groqDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index d3c912bce7..62173deeb3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -79,7 +79,7 @@ func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestHost(minimaxDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") // Delay the header 
processing to allow changing streaming mode in OnRequestBody diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index bfcf0a97da..57c4b5ae32 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -44,7 +44,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestHost(mistralDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 9eac6788ba..e76bcf529d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -60,7 +60,7 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } _ = util.OverwriteRequestPath(moonshotChatCompletionPath) _ = util.OverwriteRequestHost(moonshotDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index f5ed9270f1..42d991737f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -74,7 +74,7 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa _ = util.OverwriteRequestHost(m.customDomain) } if 
len(m.config.apiTokens) > 0 { - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) } _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index f58b8325aa..0f6b533164 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -69,7 +69,7 @@ func (m *qwenProvider) GetProviderType() string { func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { _ = util.OverwriteRequestHost(qwenDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) if m.config.protocol == protocolOriginal { ctx.DontReadRequestBody() diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 6f044f8e4f..67c1e001bd 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -73,7 +73,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } _ = util.OverwriteRequestHost(sparkHost) _ = util.OverwriteRequestPath(sparkChatCompletionPath) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index fe445fa8c8..c2a97e7ab3 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -47,7 +47,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } _ = util.OverwriteRequestPath(stepfunChatCompletionPath) _ = util.OverwriteRequestHost(stepfunDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 4631793a5c..760fc68753 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -47,7 +47,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, } _ = util.OverwriteRequestPath(yiChatCompletionPath) _ = util.OverwriteRequestHost(yiDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 8cebc247e1..a5175dbc66 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -46,7 +46,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } _ = util.OverwriteRequestPath(zhipuAiChatCompletionPath) _ = util.OverwriteRequestHost(zhipuAiDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + getApiTokenInUse(ctx)) + _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return 
types.ActionContinue, nil } diff --git a/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go b/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go index 96600192b1..d07fe12095 100644 --- a/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go @@ -27,10 +27,14 @@ type Cluster interface { } type RouteCluster struct { - Host string + Host string + Cluster string } func (c RouteCluster) ClusterName() string { + if c.Cluster != "" { + return c.Cluster + } routeName, err := proxywasm.GetProperty([]string{"cluster_name"}) if err != nil { proxywasm.LogErrorf("get route cluster failed, err:%v", err) From 374d5be733bb7f5a664c7a91e8cd40f7786f9b0e Mon Sep 17 00:00:00 2001 From: Se7en Date: Mon, 7 Oct 2024 16:50:11 +0800 Subject: [PATCH 16/31] update README.md --- plugins/wasm-go/extensions/ai-proxy/README.md | 29 +++++++++++++------ .../extensions/ai-proxy/provider/failover.go | 5 ++-- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index c63d485ca7..26083d5cc5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -31,15 +31,16 @@ description: AI 代理插件配置参考 `provider`的配置字段说明如下: -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `type` | string | 必填 | - | AI 服务提供商名称 | -| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 | -| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 | -| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 | -| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) | -| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | -| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|------------------| --------------- | -------- | ------ |-----------------------------------------------------------------------------------------------------------------------------------------------------------| +| `type` | string | 必填 | - | AI 服务提供商名称 | +| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 | +| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 | +| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 | +| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) | +| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | +| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | +| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | `context`的配置字段说明如下: @@ -75,6 +76,16 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字 如果启用了raw模式,custom-setting会直接用输入的`name`和`value`去更改请求中的json内容,而不对参数名称做任何限制和修改。 对于大多数协议,custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议,ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。 +`failover` 的配置字段说明如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|------------------|-----------------|------|-------|-----------------------| +| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 | +| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值 | +| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值 | +| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 | +| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 | +| healthCheckModel | string | 必填 | | 健康检测使用的模型 | ### 提供商特有配置 diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index e16c577b65..e6fdd61b7c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -19,7 +19,7 @@ import ( type failover struct { // @Title zh-CN 是否启用 apiToken 的 failover 机制 enabled bool `required:"true" yaml:"enabled" json:"enabled"` - // @Title zh-CN 触发 failover 的失败阈值 + // @Title zh-CN 触发 failover 连续请求失败的阈值 failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"` // @Title zh-CN 健康检测的成功阈值 successThreshold int64 `required:"false" yaml:"successThreshold" json:"successThreshold"` @@ -28,8 +28,7 @@ type failover struct { // @Title zh-CN 健康检测的超时时间,单位毫秒 
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"` // @Title zh-CN 健康检测使用的模型 - healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` - + healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` ctxApiTokenInUse string ctxApiTokenHealthCheck string ctxHealthCheckHeader string From fd49f2d89d0a5935d1671fdc99345e54da812734 Mon Sep 17 00:00:00 2001 From: Se7en Date: Mon, 7 Oct 2024 17:14:30 +0800 Subject: [PATCH 17/31] add description --- plugins/wasm-go/extensions/ai-proxy/main.go | 2 +- .../extensions/ai-proxy/provider/failover.go | 33 ++++++++++--------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 5ac79ef40e..8715f73fb4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -172,7 +172,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo log.Errorf("unable to load :status header from response: %v", err) } ctx.DontReadResponseBody() - providerConfig.OnRequestFailed(ctx, apiTokenInUse, log) + providerConfig.OnRequestFailed(apiTokenInUse, log) return types.ActionContinue } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index e6fdd61b7c..cf847f9088 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -28,16 +28,23 @@ type failover struct { // @Title zh-CN 健康检测的超时时间,单位毫秒 healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"` // @Title zh-CN 健康检测使用的模型 - healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` - ctxApiTokenInUse string - ctxApiTokenHealthCheck string - ctxHealthCheckHeader string + healthCheckModel string 
`required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` + // @Title zh-CN 本次请求使用的 apiToken + ctxApiTokenInUse string + // @Title zh-CN 标记请求是否是健康检测请求 + ctxHealthCheckHeader string + // @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数 ctxApiTokenRequestFailureCount string + // @Title zh-CN 记录 apiToken 健康检测成功的次数,key 为 apiToken,value 为成功次数 ctxApiTokenRequestSuccessCount string - ctxApiTokens string - ctxUnavailableApiTokens string - ctxRequestHostAndPath string - ctxVmLease string + // @Title zh-CN 记录所有可用的 apiToken 列表 + ctxApiTokens string + // @Title zh-CN 记录所有不可用的 apiToken 列表 + ctxUnavailableApiTokens string + // @Title zh-CN 记录请求的 host 和 path,用于在健康检测时构建请求 + ctxRequestHostAndPath string + // @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测 + ctxVmLease string } type Lease struct { @@ -95,7 +102,6 @@ func (c *ProviderConfig) initVariable() { // Set provider name as prefix to differentiate shared data provider := c.GetType() c.failover.ctxApiTokenInUse = provider + "-apiTokenInUse" - c.failover.ctxApiTokenHealthCheck = provider + "-apiTokenHealthCheck" c.failover.ctxHealthCheckHeader = provider + "-apiToken-health-check" c.failover.ctxApiTokenRequestFailureCount = provider + "-apiTokenRequestFailureCount" c.failover.ctxApiTokenRequestSuccessCount = provider + "-apiTokenRequestSuccessCount" @@ -131,7 +137,6 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { for _, apiToken := range unavailableTokens { log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", ")) hostPath, headers, body := c.generateRequestHeadersAndBody(apiToken, log) - fmt.Println("host", hostPath.Host, "path", hostPath.Path) healthCheckClient = wrapper.NewClusterClient(wrapper.RouteCluster{ Host: hostPath.Host, Cluster: higressGatewayLocalCluster, @@ -251,7 +256,7 @@ func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Lo addApiToken(c.failover.ctxApiTokens, apiToken, log) 
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log) } else { - log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check......", apiToken, successCount) + log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check...", apiToken, successCount) addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log) } } @@ -452,7 +457,6 @@ func (c *ProviderConfig) initApiTokens() error { } func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string { - fmt.Println(c.apiTokens) apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens) unavailableApiTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens) log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens) @@ -483,9 +487,8 @@ func (c *ProviderConfig) resetSharedData() { _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0) } -func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) { - // If apiToken failover is enabled and the request is not a health check request, handle unavailable apiToken. 
- if c.IsFailoverEnabled() && ctx.GetContext(c.failover.ctxApiTokenHealthCheck) == nil { +func (c *ProviderConfig) OnRequestFailed(apiTokenInUse string, log wrapper.Log) { + if c.IsFailoverEnabled() { c.handleUnavailableApiToken(apiTokenInUse, log) } } From 66c371b0c2982f0a9dc666046f06c40f109db34d Mon Sep 17 00:00:00 2001 From: Se7en Date: Mon, 7 Oct 2024 18:48:37 +0800 Subject: [PATCH 18/31] fix nil point exception when don't set failover config --- .../extensions/ai-proxy/provider/failover.go | 67 ++++++++++--------- .../extensions/ai-proxy/provider/provider.go | 4 +- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index cf847f9088..fba144e757 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -112,21 +112,24 @@ func (c *ProviderConfig) initVariable() { } func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { + c.initVariable() // Reset shared data in case plugin configuration is updated - log.Debugf("Wasm plugin configuration is updated, reset shared data") + log.Debugf("ai-proxy plugin configuration is updated, reset shared data") c.resetSharedData() - c.initVariable() - vmID := generateVMID() - err := c.initApiTokens() - if err != nil { - return fmt.Errorf("failed to init apiTokens: %v", err) - } + if c.isFailoverEnabled() { + log.Debugf("ai-proxy plugin failover is enabled") + + vmID := generateVMID() + err := c.initApiTokens() + + if err != nil { + return fmt.Errorf("failed to init apiTokens: %v", err) + } - if c.failover != nil && c.failover.enabled { wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() { // Only the Wasm VM that successfully acquires the lease will perform health check - if c.tryAcquireOrRenewLease(vmID, log) { + if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID, log) { 
log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType()) unavailableTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens) if err != nil { @@ -412,12 +415,14 @@ func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) { } func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) { - failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount) - if err != nil { - log.Errorf("failed to get failureApiTokenRequestCount: %v", err) - } - if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok { - resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log) + if c.isFailoverEnabled() { + failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount) + if err != nil { + log.Errorf("failed to get failureApiTokenRequestCount: %v", err) + } + if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok { + resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log) + } } } @@ -475,8 +480,8 @@ func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string { } } -func (c *ProviderConfig) IsFailoverEnabled() bool { - return c.failover != nil && c.failover.enabled +func (c *ProviderConfig) isFailoverEnabled() bool { + return c.failover.enabled } func (c *ProviderConfig) resetSharedData() { @@ -488,7 +493,7 @@ func (c *ProviderConfig) resetSharedData() { } func (c *ProviderConfig) OnRequestFailed(apiTokenInUse string, log wrapper.Log) { - if c.IsFailoverEnabled() { + if c.isFailoverEnabled() { c.handleUnavailableApiToken(apiTokenInUse, log) } } @@ -499,7 +504,7 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { apiToken := c.GetRandomToken() - if c.IsFailoverEnabled() { + if c.isFailoverEnabled() { // Use the health check token if 
it is a health check request. if apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader(c.failover.ctxHealthCheckHeader); apiTokenHealthCheck != "" { apiToken = apiTokenHealthCheck @@ -513,17 +518,19 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L } func (c *ProviderConfig) SetRequestHostAndPath(log wrapper.Log) { - hostPath := HostPath{ - Host: wrapper.GetRequestHost(), - Path: wrapper.GetRequestPath(), - } - hostPathByte, err := json.Marshal(hostPath) - if err != nil { - log.Errorf("Failed to marshal request host and path: %v", err) + if c.isFailoverEnabled() { + hostPath := HostPath{ + Host: wrapper.GetRequestHost(), + Path: wrapper.GetRequestPath(), + } + hostPathByte, err := json.Marshal(hostPath) + if err != nil { + log.Errorf("Failed to marshal request host and path: %v", err) - } - err = proxywasm.SetSharedData(c.failover.ctxRequestHostAndPath, hostPathByte, 0) - if err != nil { - log.Errorf("Failed to set request host and path: %v", err) + } + err = proxywasm.SetSharedData(c.failover.ctxRequestHostAndPath, hostPathByte, 0) + if err != nil { + log.Errorf("Failed to set request host and path: %v", err) + } } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 1d9fc0d4ca..33809447ea 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -289,8 +289,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } failoverJson := json.Get("failover") + c.failover = &failover{ + enabled: false, + } if failoverJson.Exists() { - c.failover = &failover{} c.failover.FromJson(failoverJson) } } From 7f36c0942a241b564bbb739389c8441ac4af766d Mon Sep 17 00:00:00 2001 From: Se7en Date: Mon, 7 Oct 2024 19:02:48 +0800 Subject: [PATCH 19/31] support github provider --- plugins/wasm-go/extensions/ai-proxy/provider/github.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index 5ee51b2742..c1bd1b619d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -57,7 +57,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - _ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken()) + _ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetApiTokenInUse(ctx)) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } From 01b92d85851d4c415003ac13747785cb37cfe2ef Mon Sep 17 00:00:00 2001 From: Se7en Date: Thu, 10 Oct 2024 21:02:41 +0800 Subject: [PATCH 20/31] fix --- plugins/wasm-go/extensions/ai-proxy/provider/provider.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index fef1d9fffc..39207afe9a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -313,7 +313,7 @@ func (c *ProviderConfig) Validate() error { } } - if c.failover != nil { + if c.failover.enabled { if err := c.failover.Validate(); err != nil { return err } From 01b0eec6a79e2fa1dcf399f27d898bed268a1b14 Mon Sep 17 00:00:00 2001 From: Se7en Date: Thu, 17 Oct 2024 12:32:47 +0800 Subject: [PATCH 21/31] unified the transformation of HTTP headers and body for ai-proxy and http call --- helm/core/templates/_pod.tpl | 6 + helm/core/templates/configmap.yaml | 35 +----- .../extensions/ai-proxy/config/config.go | 2 +- plugins/wasm-go/extensions/ai-proxy/main.go | 2 - .../extensions/ai-proxy/provider/ai360.go | 11 ++ 
.../extensions/ai-proxy/provider/azure.go | 11 ++ .../extensions/ai-proxy/provider/baichuan.go | 11 ++ .../extensions/ai-proxy/provider/baidu.go | 11 ++ .../extensions/ai-proxy/provider/claude.go | 94 ++++++---------- .../ai-proxy/provider/cloudflare.go | 11 ++ .../extensions/ai-proxy/provider/cohere.go | 11 ++ .../extensions/ai-proxy/provider/deepl.go | 11 ++ .../extensions/ai-proxy/provider/deepseek.go | 11 ++ .../extensions/ai-proxy/provider/doubao.go | 11 ++ .../extensions/ai-proxy/provider/failover.go | 104 ++++++++++-------- .../extensions/ai-proxy/provider/gemini.go | 11 ++ .../extensions/ai-proxy/provider/github.go | 11 ++ .../extensions/ai-proxy/provider/groq.go | 65 ++++++----- .../extensions/ai-proxy/provider/hunyuan.go | 11 ++ .../extensions/ai-proxy/provider/minimax.go | 11 ++ .../extensions/ai-proxy/provider/mistral.go | 11 ++ .../extensions/ai-proxy/provider/moonshot.go | 10 ++ .../extensions/ai-proxy/provider/ollama.go | 11 ++ .../extensions/ai-proxy/provider/openai.go | 11 ++ .../extensions/ai-proxy/provider/provider.go | 20 ++++ .../extensions/ai-proxy/provider/qwen.go | 11 ++ .../ai-proxy/provider/request_helper.go | 9 ++ .../extensions/ai-proxy/provider/spark.go | 11 ++ .../extensions/ai-proxy/provider/stepfun.go | 11 ++ .../extensions/ai-proxy/provider/yi.go | 11 ++ .../extensions/ai-proxy/provider/zhipuai.go | 11 ++ .../wasm-go/extensions/ai-proxy/util/http.go | 62 ++++++++++- .../wasm-go/pkg/wrapper/cluster_wrapper.go | 19 +++- 33 files changed, 486 insertions(+), 173 deletions(-) diff --git a/helm/core/templates/_pod.tpl b/helm/core/templates/_pod.tpl index 9bc7b744b2..4e7e0a6ac7 100644 --- a/helm/core/templates/_pod.tpl +++ b/helm/core/templates/_pod.tpl @@ -123,8 +123,10 @@ template: - name: LITE_METRICS value: "on" {{- end }} + {{- if include "skywalking.enabled" . 
}} - name: ISTIO_BOOTSTRAP_OVERRIDE value: /etc/istio/custom-bootstrap/custom_bootstrap.json + {{- end }} {{- with .Values.gateway.networkGateway }} - name: ISTIO_META_REQUESTED_NETWORK_VIEW value: "{{.}}" @@ -186,8 +188,10 @@ template: mountPath: /etc/istio/pod - name: proxy-socket mountPath: /etc/istio/proxy + {{- if include "skywalking.enabled" . }} - mountPath: /etc/istio/custom-bootstrap name: custom-bootstrap-volume + {{- end }} {{- if .Values.global.volumeWasmPlugins }} - mountPath: /opt/plugins name: local-wasmplugins-volume @@ -272,10 +276,12 @@ template: - name: config configMap: name: higress-config + {{- if include "skywalking.enabled" . }} - configMap: defaultMode: 420 name: higress-custom-bootstrap name: custom-bootstrap-volume + {{- end }} - name: istio-data emptyDir: {} - name: proxy-socket diff --git a/helm/core/templates/configmap.yaml b/helm/core/templates/configmap.yaml index 2d7ae5b176..b7814f5bf7 100644 --- a/helm/core/templates/configmap.yaml +++ b/helm/core/templates/configmap.yaml @@ -136,6 +136,7 @@ data: {{- include "mesh" . }} {{- end }} --- +{{- if include "skywalking.enabled" . }} apiVersion: v1 kind: ConfigMap metadata: @@ -146,7 +147,6 @@ metadata: data: custom_bootstrap.json: |- { - {{- if include "skywalking.enabled" . 
}} "stats_sinks": [ { "name": "envoy.metrics_service", @@ -160,34 +160,7 @@ data: } } } - ], - {{- end }} - "static_resources": { - "clusters": [ - { - "name": "higress-gateway-local", - "type": "STATIC", - "connect_timeout": "5s", - "load_assignment": { - "cluster_name": "higress-gateway-local", - "endpoints": [ - { - "lb_endpoints": [ - { - "endpoint": { - "address": { - "socket_address": { - "address": "127.0.0.1", - "port_value": {{ .Values.gateway.httpPort }} - } - } - } - } - ] - } - ] - } - } - ] - } + ] } +--- +{{- end }} diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go index ab9962f525..48f08dd9e4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/config/config.go +++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go @@ -84,7 +84,7 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig) providerConfig := c.GetProviderConfig() - err = providerConfig.SetApiTokensFailover(log) + err = providerConfig.SetApiTokensFailover(log, c.activeProvider) return err } diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 8715f73fb4..532fee8259 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -94,8 +94,6 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf ctx.DisableReroute() // Set the apiToken for the current request. 
providerConfig.SetApiTokenInUse(ctx, log) - // Set the request host and path to shared data in case they are needed in apiToken health check - providerConfig.SetRequestHostAndPath(log) hasRequestBody := wrapper.HasRequestBody() action, err := handler.OnRequestHeaders(ctx, apiName, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index ae77649e51..f17052e3ae 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -24,6 +25,16 @@ type ai360Provider struct { contextCache *contextCache } +func (m *ai360Provider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *ai360Provider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *ai360ProviderInitializer) ValidateConfig(config ProviderConfig) error { if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 959bd94061..a1f1af2301 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -3,6 +3,7 @@ package provider import ( "errors" "fmt" + "net/http" "net/url" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -50,6 +51,16 @@ type azureProvider struct { serviceUrl *url.URL } +func (m *azureProvider) 
TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *azureProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *azureProvider) GetProviderType() string { return providerTypeAzure } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index af35f920d5..25d063edc6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -3,6 +3,7 @@ package provider import ( "errors" "fmt" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -39,6 +40,16 @@ type baichuanProvider struct { contextCache *contextCache } +func (m *baichuanProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *baichuanProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *baichuanProvider) GetProviderType() string { return providerTypeBaichuan } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index d0197e4aa5..0edb50e6da 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -52,6 +53,16 @@ type baiduProvider struct { contextCache *contextCache } +func (b *baiduProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (b 
*baiduProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (b *baiduProvider) GetProviderType() string { return providerTypeBaidu } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 90502a6605..26a5fc914a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -106,18 +107,26 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(claudeChatCompletionPath) - _ = util.OverwriteRequestHost(claudeDomain) - _ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetApiTokenInUse(ctx)) + originalHeaders := util.GetOriginaHttplHeaders() + c.TransformRequestHeaders(originalHeaders, ctx, log) + util.ReplaceOriginalHttpHeaders(originalHeaders) + + return types.ActionContinue, nil +} + +func (c *claudeProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + util.OverwriteHttpRequestPath(headers, claudeChatCompletionPath) + util.OverwriteHttpRequestHost(headers, claudeDomain) + + headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx)) if c.config.claudeVersion == "" { c.config.claudeVersion = defaultVersion } - _ = proxywasm.AddHttpRequestHeader("anthropic-version", c.config.claudeVersion) - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - return types.ActionContinue, nil + headers.Add("anthropic-version", c.config.claudeVersion) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) 
(types.Action, error) { @@ -135,72 +144,41 @@ func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if err := json.Unmarshal(body, request); err != nil { return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) } + return types.ActionContinue, nil + } - err := c.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } + // use openai protocol + modifiedBody, err := c.TransformRequestBody(body, ctx, log) + if err != nil { + return types.ActionContinue, err + } + err = replaceHttpJsonRequestBody(modifiedBody, log) + if err != nil { return types.ActionContinue, err } - // use openai protocol + return types.ActionContinue, nil + +} + +func (c *claudeProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + return nil, err } - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") + err := c.config.setRequestModel(ctx, request, log) + if err != nil { + return nil, err } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, c.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - 
request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) streaming := request.Stream if streaming { _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") } - if c.config.context == nil { - claudeRequest := c.buildClaudeTextGenRequest(request) - return types.ActionContinue, replaceJsonRequestBody(claudeRequest, log) - } - - err := c.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - claudeRequest := c.buildClaudeTextGenRequest(request) - if err := replaceJsonRequestBody(claudeRequest, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + claudeRequest := c.buildClaudeTextGenRequest(request) + return json.Marshal(claudeRequest) } func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index b52f942b65..d0fdbad1bf 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -3,6 +3,7 @@ package provider import ( "errors" "fmt" + "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -39,6 +40,16 @@ type cloudflareProvider struct { contextCache *contextCache } +func (c *cloudflareProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log 
wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (c *cloudflareProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (c *cloudflareProvider) GetProviderType() string { return providerTypeCloudflare } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index f233955afa..e530a6ef9e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -35,6 +36,16 @@ type cohereProvider struct { config ProviderConfig } +func (m *cohereProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *cohereProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + type cohereTextGenRequest struct { Message string `json:"message,omitempty"` Model string `json:"model,omitempty"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 6ff536dd70..786d635fec 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -27,6 +28,16 @@ type deeplProvider struct { contextCache *contextCache } +func (d *deeplProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + 
+func (d *deeplProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + // spec reference: https://developers.deepl.com/docs/v/zh/api-reference/translate/openapi-spec-for-text-translation type deeplRequest struct { // "Model" parameter is used to distinguish which service to use diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index ceb1eb1eba..331f7dd019 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -3,6 +3,7 @@ package provider import ( "errors" "fmt" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -39,6 +40,16 @@ type deepseekProvider struct { contextCache *contextCache } +func (m *deepseekProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *deepseekProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *deepseekProvider) GetProviderType() string { return providerTypeDeepSeek } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 5e4d847be3..1e0db5b441 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -3,6 +3,7 @@ package provider import ( "errors" "fmt" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -36,6 +37,16 @@ type doubaoProvider struct { contextCache *contextCache } +func (m *doubaoProvider) TransformRequestHeaders(headers http.Header, 
ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *doubaoProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *doubaoProvider) GetProviderType() string { return providerTypeDoubao } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index fba144e757..364edd1e03 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/google/uuid" "math/rand" "net/http" @@ -31,8 +32,6 @@ type failover struct { healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` // @Title zh-CN 本次请求使用的 apiToken ctxApiTokenInUse string - // @Title zh-CN 标记请求是否是健康检测请求 - ctxHealthCheckHeader string // @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数 ctxApiTokenRequestFailureCount string // @Title zh-CN 记录 apiToken 健康检测成功的次数,key 为 apiToken,value 为成功次数 @@ -41,8 +40,8 @@ type failover struct { ctxApiTokens string // @Title zh-CN 记录所有不可用的 apiToken 列表 ctxUnavailableApiTokens string - // @Title zh-CN 记录请求的 host 和 path,用于在健康检测时构建请求 - ctxRequestHostAndPath string + // @Title zh-CN 记录请求的 cluster, host 和 path,用于在健康检测时构建请求 + ctxHealthCheckEndpoint string // @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测 ctxVmLease string } @@ -52,9 +51,10 @@ type Lease struct { Timestamp int64 `json:"timestamp"` } -type HostPath struct { - Host string `json:"host"` - Path string `json:"path"` +type HealthCheckEndpoint struct { + Host string `json:"host"` + Path string `json:"path"` + Cluster string `json:"cluster"` } const ( @@ -63,7 +63,6 @@ const ( removeApiTokenOperation = "removeApiToken" addApiTokenRequestCountOperation = 
"addApiTokenRequestCount" resetApiTokenRequestCountOperation = "resetApiTokenRequestCount" - higressGatewayLocalCluster = "higress-gateway-local" ) var ( @@ -102,16 +101,19 @@ func (c *ProviderConfig) initVariable() { // Set provider name as prefix to differentiate shared data provider := c.GetType() c.failover.ctxApiTokenInUse = provider + "-apiTokenInUse" - c.failover.ctxHealthCheckHeader = provider + "-apiToken-health-check" c.failover.ctxApiTokenRequestFailureCount = provider + "-apiTokenRequestFailureCount" c.failover.ctxApiTokenRequestSuccessCount = provider + "-apiTokenRequestSuccessCount" c.failover.ctxApiTokens = provider + "-apiTokens" c.failover.ctxUnavailableApiTokens = provider + "-unavailableApiTokens" - c.failover.ctxRequestHostAndPath = provider + "-requestHostAndPath" + c.failover.ctxHealthCheckEndpoint = provider + "-requestHostAndPath" c.failover.ctxVmLease = provider + "-vmLease" } -func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { +func parseConfig(json gjson.Result, config *any, log wrapper.Log) error { + return nil +} + +func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Provider) error { c.initVariable() // Reset shared data in case plugin configuration is updated log.Debugf("ai-proxy plugin configuration is updated, reset shared data") @@ -139,12 +141,24 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { if len(unavailableTokens) > 0 { for _, apiToken := range unavailableTokens { log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", ")) - hostPath, headers, body := c.generateRequestHeadersAndBody(apiToken, log) - healthCheckClient = wrapper.NewClusterClient(wrapper.RouteCluster{ - Host: hostPath.Host, - Cluster: higressGatewayLocalCluster, + healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody(log) + healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{ + Host: healthCheckEndpoint.Host, + 
Cluster: healthCheckEndpoint.Cluster, }) - err = healthCheckClient.Post(hostPath.Path, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + + setParseConfig := wrapper.ParseConfigBy[any](parseConfig) + vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig) + pluginCtx := vmCtx.NewPluginContext(rand.Uint32()) + ctx := pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any]) + ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) + + originalHeaders := util.SliceToHeader(headers) + activeProvider.TransformRequestHeaders(originalHeaders, ctx, log) + modifiedHeaders := util.HeaderToSlice(originalHeaders) + modifiedBody, _ := activeProvider.TransformRequestBody(body, ctx, log) + + err = healthCheckClient.Post(healthCheckEndpoint.Path, modifiedHeaders, modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode == 200 { c.handleAvailableApiToken(apiToken, log) } @@ -160,21 +174,19 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log) error { return nil } -func (c *ProviderConfig) generateRequestHeadersAndBody(apiToken string, log wrapper.Log) (HostPath, [][2]string, []byte) { - data, _, err := proxywasm.GetSharedData(c.failover.ctxRequestHostAndPath) +func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthCheckEndpoint, [][2]string, []byte) { + data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint) if err != nil { log.Errorf("Failed to get request host and path: %v", err) } - var hostPath HostPath - err = json.Unmarshal(data, &hostPath) + var healthCheckEndpoint HealthCheckEndpoint + err = json.Unmarshal(data, &healthCheckEndpoint) if err != nil { log.Errorf("Failed to unmarshal request host and path: %v", err) } headers := [][2]string{ {"content-type", "application/json"}, - {c.failover.ctxHealthCheckHeader, apiToken}, - {":authority", hostPath.Host}, } body := []byte(fmt.Sprintf(`{ "model": "%s", @@ -185,7 
+197,7 @@ func (c *ProviderConfig) generateRequestHeadersAndBody(apiToken string, log wrap } ] }`, c.failover.healthCheckModel)) - return hostPath, headers, body + return healthCheckEndpoint, headers, body } func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool { @@ -289,6 +301,8 @@ func (c *ProviderConfig) handleUnavailableApiToken(apiToken string, log wrapper. removeApiToken(c.failover.ctxApiTokens, apiToken, log) addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log) resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log) + // Set the request host and path to shared data in case they are needed in apiToken health check + c.setHealthCheckEndpoint(log) } else { log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount) addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log) @@ -503,34 +517,34 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { } func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { - apiToken := c.GetRandomToken() + var apiToken string if c.isFailoverEnabled() { - // Use the health check token if it is a health check request. 
- if apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader(c.failover.ctxHealthCheckHeader); apiTokenHealthCheck != "" { - apiToken = apiTokenHealthCheck - } else { - // if enable apiToken failover, only use available apiToken - apiToken = c.GetGlobalRandomToken(log) - } + // if enable apiToken failover, only use available apiToken + apiToken = c.GetGlobalRandomToken(log) + } else { + apiToken = c.GetRandomToken() } log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken) ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) } -func (c *ProviderConfig) SetRequestHostAndPath(log wrapper.Log) { - if c.isFailoverEnabled() { - hostPath := HostPath{ - Host: wrapper.GetRequestHost(), - Path: wrapper.GetRequestPath(), - } - hostPathByte, err := json.Marshal(hostPath) - if err != nil { - log.Errorf("Failed to marshal request host and path: %v", err) +func (c *ProviderConfig) setHealthCheckEndpoint(log wrapper.Log) { + cluster, err := proxywasm.GetProperty([]string{"cluster_name"}) + if err != nil { + log.Errorf("Failed to get cluster_name: %v", err) + } + hostPath := HealthCheckEndpoint{ + Host: wrapper.GetRequestHost(), + Path: wrapper.GetRequestPath(), + Cluster: string(cluster), + } + hostPathByte, err := json.Marshal(hostPath) + if err != nil { + log.Errorf("Failed to marshal request host and path: %v", err) - } - err = proxywasm.SetSharedData(c.failover.ctxRequestHostAndPath, hostPathByte, 0) - if err != nil { - log.Errorf("Failed to set request host and path: %v", err) - } + } + err = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, hostPathByte, 0) + if err != nil { + log.Errorf("Failed to set request host and path: %v", err) } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index b9c157d864..b0563dc926 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -4,6 +4,7 @@ import ( 
"encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -43,6 +44,16 @@ type geminiProvider struct { contextCache *contextCache } +func (g *geminiProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (g *geminiProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (g *geminiProvider) GetProviderType() string { return providerTypeGemini } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index c1bd1b619d..ce9ea75bb6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -26,6 +27,16 @@ type githubProvider struct { contextCache *contextCache } +func (m *githubProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *githubProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *githubProviderInitializer) ValidateConfig(config ProviderConfig) error { if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index e97c92c330..c2a67abf62 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -1,12 +1,12 @@ package provider import ( + "encoding/json" "errors" - "fmt" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -18,14 +18,14 @@ const ( type groqProviderInitializer struct{} -func (m *groqProviderInitializer) ValidateConfig(config ProviderConfig) error { +func (g *groqProviderInitializer) ValidateConfig(config ProviderConfig) error { if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") } return nil } -func (m *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { +func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { return &groqProvider{ config: config, contextCache: createContextCache(&config), @@ -37,47 +37,52 @@ type groqProvider struct { contextCache *contextCache } -func (m *groqProvider) GetProviderType() string { +func (g *groqProvider) GetProviderType() string { return providerTypeGroq } -func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(groqChatCompletionPath) - _ = util.OverwriteRequestHost(groqDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + originalHeaders := util.GetOriginaHttplHeaders() + g.TransformRequestHeaders(originalHeaders, ctx, log) + util.ReplaceOriginalHttpHeaders(originalHeaders) return types.ActionContinue, nil } 
-func (m *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { +func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil + modifiedBody, err := g.TransformRequestBody(body, ctx, log) + if err != nil { + return types.ActionContinue, err + } + err = replaceHttpJsonRequestBody(modifiedBody, log) + if err != nil { + return types.ActionContinue, err } + return types.ActionContinue, nil +} + +func (g *groqProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + util.OverwriteHttpRequestPath(headers, groqChatCompletionPath) + util.OverwriteHttpRequestHost(headers, groqDomain) + util.OverwriteHttpRequestAuthorization(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (g *groqProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + return nil, err } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.groq.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.groq.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + + err := 
g.config.setRequestModel(ctx, request, log) + if err != nil { + return nil, err } - return types.ActionContinue, err + + return json.Marshal(request) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index 7640a380b3..757447dcf0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -109,6 +110,16 @@ type hunyuanProvider struct { contextCache *contextCache } +func (m *hunyuanProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *hunyuanProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *hunyuanProvider) GetProviderType() string { return providerTypeHunyuan } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 62173deeb3..83c7952867 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -70,6 +71,16 @@ type minimaxProvider struct { contextCache *contextCache } +func (m *minimaxProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *minimaxProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *minimaxProvider) GetProviderType() string { return providerTypeMinimax } diff --git 
a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 57c4b5ae32..343bd483e4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -3,6 +3,7 @@ package provider import ( "errors" "fmt" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -35,6 +36,16 @@ type mistralProvider struct { contextCache *contextCache } +func (m *mistralProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *mistralProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *mistralProvider) GetProviderType() string { return providerTypeMistral } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index e76bcf529d..2eb834676f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -50,6 +50,16 @@ type moonshotProvider struct { contextCache *contextCache } +func (m *moonshotProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *moonshotProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *moonshotProvider) GetProviderType() string { return providerTypeMoonshot } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 8895489fbe..736af60e6f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -3,6 +3,7 @@ package provider import ( "errors" "fmt" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -45,6 +46,16 @@ type ollamaProvider struct { contextCache *contextCache } +func (m *ollamaProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *ollamaProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *ollamaProvider) GetProviderType() string { return providerTypeOllama } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 42d991737f..9b57c28e94 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -2,6 +2,7 @@ package provider import ( "fmt" + "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -52,6 +53,16 @@ type openaiProvider struct { contextCache *contextCache } +func (m *openaiProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *openaiProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *openaiProvider) GetProviderType() string { return providerTypeOpenAI } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 39207afe9a..3b07b9a127 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -3,6 +3,7 @@ package provider import ( 
"errors" "math/rand" + "net/http" "strings" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" @@ -106,6 +107,8 @@ var ( type Provider interface { GetProviderType() string + TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) + TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) } type RequestHeadersHandler interface { @@ -370,6 +373,23 @@ func CreateProvider(pc ProviderConfig) (Provider, error) { return initializer.CreateProvider(pc) } +func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request *chatCompletionRequest, log wrapper.Log) error { + model := request.Model + if model == "" { + return errors.New("missing model in chat completion request") + } + + ctx.SetContext(ctxKeyOriginalRequestModel, model) + mappedModel := getMappedModel(model, c.modelMapping, log) + if mappedModel == "" { + return errors.New("model becomes empty after applying the configured mapping") + } + request.Model = mappedModel + ctx.SetContext(ctxKeyFinalRequestModel, request.Model) + + return nil +} + func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string { mappedModel := doGetMappedModel(model, modelMapping, log) if len(mappedModel) != 0 { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 0f6b533164..09f9f58374 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math" + "net/http" "reflect" "strings" "time" @@ -63,6 +64,16 @@ type qwenProvider struct { contextCache *contextCache } +func (m *qwenProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *qwenProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO 
implement me + panic("implement me") +} + func (m *qwenProvider) GetProviderType() string { return providerTypeQwen } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go index 19060849ac..aa8b6104fb 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go @@ -31,6 +31,15 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error { return err } +func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error { + log.Debugf("request body: %s", string(body)) + err := proxywasm.ReplaceHttpRequestBody(body) + if err != nil { + return fmt.Errorf("unable to replace the original request body: %v", err) + } + return nil +} + func insertContextMessage(request *chatCompletionRequest, content string) { fileMessage := chatMessage{ Role: roleSystem, diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 67c1e001bd..6dd42e9131 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -27,6 +28,16 @@ type sparkProvider struct { contextCache *contextCache } +func (p *sparkProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (p *sparkProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + type sparkRequest struct { Model string `json:"model"` Messages []chatMessage `json:"messages"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index c2a97e7ab3..987a8106a0 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -3,6 +3,7 @@ package provider import ( "errors" "fmt" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -37,6 +38,16 @@ type stepfunProvider struct { contextCache *contextCache } +func (m *stepfunProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *stepfunProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *stepfunProvider) GetProviderType() string { return providerTypeStepfun } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 760fc68753..76719ed51a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -3,6 +3,7 @@ package provider import ( "errors" "fmt" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -37,6 +38,16 @@ type yiProvider struct { contextCache *contextCache } +func (m *yiProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *yiProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *yiProvider) GetProviderType() string { return providerTypeYi } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index a5175dbc66..1fe87a63ca 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -3,6 +3,7 @@ package provider import ( "errors" "fmt" + "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -36,6 +37,16 @@ type zhipuAiProvider struct { contextCache *contextCache } +func (m *zhipuAiProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { + //TODO implement me + panic("implement me") +} + +func (m *zhipuAiProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { + //TODO implement me + panic("implement me") +} + func (m *zhipuAiProvider) GetProviderType() string { return providerTypeZhipuAi } diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index 43135ec0a2..fa0d119baf 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -1,6 +1,9 @@ package util -import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" +import ( + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "net/http" +) const ( HeaderContentType = "Content-Type" @@ -21,6 +24,7 @@ func CreateHeaders(kvs ...string) [][2]string { return headers } +// TODO: remove func OverwriteRequestHost(host string) error { if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil { _ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-HOST", originHost) @@ -28,6 +32,7 @@ func OverwriteRequestHost(host string) error { return proxywasm.ReplaceHttpRequestHeader(":authority", host) } +// TODO: remove func OverwriteRequestPath(path string) error { if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil { _ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-PATH", originPath) @@ -35,6 +40,7 @@ func OverwriteRequestPath(path string) error { return proxywasm.ReplaceHttpRequestHeader(":path", 
path) } +// TODO: remove func OverwriteRequestAuthorization(credential string) error { if exist, _ := proxywasm.GetHttpRequestHeader("X-HI-ORIGINAL-AUTH"); exist == "" { if originAuth, err := proxywasm.GetHttpRequestHeader("Authorization"); err == nil { @@ -43,3 +49,57 @@ func OverwriteRequestAuthorization(credential string) error { } return proxywasm.ReplaceHttpRequestHeader("Authorization", credential) } + +func OverwriteHttpRequestHost(headers http.Header, host string) { + if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil { + headers.Set("X-ENVOY-ORIGINAL-HOST", originHost) + } + headers.Set(":authority", host) +} + +func OverwriteHttpRequestPath(headers http.Header, path string) { + if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil { + headers.Set("X-ENVOY-ORIGINAL-PATH", originPath) + } + headers.Set(":path", path) +} + +func OverwriteHttpRequestAuthorization(headers http.Header, credential string) { + if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" { + if originAuth := headers.Get("Authorization"); originAuth != "" { + headers.Set("X-HI-ORIGINAL-AUTH", originAuth) + } + } + headers.Set("Authorization", credential) + +} + +func HeaderToSlice(header http.Header) [][2]string { + slice := make([][2]string, 0, len(header)) + for key, values := range header { + for _, value := range values { + slice = append(slice, [2]string{key, value}) + } + } + return slice +} + +func SliceToHeader(slice [][2]string) http.Header { + header := make(http.Header) + for _, pair := range slice { + key := pair[0] + value := pair[1] + header.Add(key, value) + } + return header +} + +func GetOriginaHttplHeaders() http.Header { + originalHeaders, _ := proxywasm.GetHttpRequestHeaders() + return SliceToHeader(originalHeaders) +} + +func ReplaceOriginalHttpHeaders(headers http.Header) { + modifiedHeaders := HeaderToSlice(headers) + _ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders) +} diff --git 
a/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go b/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go index d07fe12095..e797394b54 100644 --- a/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go @@ -27,14 +27,10 @@ type Cluster interface { } type RouteCluster struct { - Host string - Cluster string + Host string } func (c RouteCluster) ClusterName() string { - if c.Cluster != "" { - return c.Cluster - } routeName, err := proxywasm.GetProperty([]string{"cluster_name"}) if err != nil { proxywasm.LogErrorf("get route cluster failed, err:%v", err) @@ -49,6 +45,19 @@ func (c RouteCluster) HostName() string { return GetRequestHost() } +type TargetCluster struct { + Host string + Cluster string +} + +func (c TargetCluster) ClusterName() string { + return c.Cluster +} + +func (c TargetCluster) HostName() string { + return c.Host +} + type K8sCluster struct { ServiceName string Namespace string From a180e658d6551af58658f15504add1f1e3377cca Mon Sep 17 00:00:00 2001 From: Se7en Date: Thu, 17 Oct 2024 12:48:55 +0800 Subject: [PATCH 22/31] fix readme --- plugins/wasm-go/extensions/ai-proxy/README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 059877508e..61888b4d98 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -78,14 +78,14 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字 `failover` 的配置字段说明如下: -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -|------------------|-----------------|------|-------|-----------------------| -| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 | -| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值 | -| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值 | -| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 | -| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 | -| healthCheckModel | int | 必填 | | 
健康检测使用的模型 | +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|------------------|--------|------|-------|-----------------------------| +| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 | +| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) | +| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) | +| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 | +| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 | +| healthCheckModel | string | 必填 | | 健康检测使用的模型 | ### 提供商特有配置 From a72a8a1a4c4c964f40ddbbfecc09f512372df261 Mon Sep 17 00:00:00 2001 From: Se7en Date: Thu, 17 Oct 2024 15:16:02 +0800 Subject: [PATCH 23/31] optimize --- .../extensions/ai-proxy/provider/failover.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 364edd1e03..569e5511f2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -147,10 +147,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr Cluster: healthCheckEndpoint.Cluster, }) - setParseConfig := wrapper.ParseConfigBy[any](parseConfig) - vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig) - pluginCtx := vmCtx.NewPluginContext(rand.Uint32()) - ctx := pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any]) + ctx := createHttpContext() ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) originalHeaders := util.SliceToHeader(headers) @@ -174,6 +171,14 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr return nil } +func createHttpContext() *wrapper.CommonHttpCtx[any] { + setParseConfig := wrapper.ParseConfigBy[any](parseConfig) + vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig) + pluginCtx := vmCtx.NewPluginContext(rand.Uint32()) + ctx := 
pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any]) + return ctx +} + func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthCheckEndpoint, [][2]string, []byte) { data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint) if err != nil { From 6a623335a7cc9b39777da29c969879b5c15801cb Mon Sep 17 00:00:00 2001 From: Se7en Date: Sun, 3 Nov 2024 15:13:04 +0800 Subject: [PATCH 24/31] refine transform headers and body --- plugins/wasm-go/extensions/ai-proxy/main.go | 2 +- .../extensions/ai-proxy/provider/ai360.go | 16 +- .../extensions/ai-proxy/provider/azure.go | 11 - .../extensions/ai-proxy/provider/baichuan.go | 12 -- .../extensions/ai-proxy/provider/baidu.go | 11 - .../extensions/ai-proxy/provider/claude.go | 58 ++--- .../ai-proxy/provider/cloudflare.go | 11 - .../extensions/ai-proxy/provider/cohere.go | 12 -- .../extensions/ai-proxy/provider/context.go | 56 +++++ .../extensions/ai-proxy/provider/deepl.go | 11 - .../extensions/ai-proxy/provider/deepseek.go | 12 -- .../extensions/ai-proxy/provider/doubao.go | 12 -- .../extensions/ai-proxy/provider/failover.go | 61 ++++-- .../extensions/ai-proxy/provider/gemini.go | 11 - .../extensions/ai-proxy/provider/github.go | 16 +- .../extensions/ai-proxy/provider/groq.go | 31 +-- .../extensions/ai-proxy/provider/hunyuan.go | 11 - .../extensions/ai-proxy/provider/minimax.go | 11 - .../extensions/ai-proxy/provider/mistral.go | 12 -- .../extensions/ai-proxy/provider/moonshot.go | 10 - .../extensions/ai-proxy/provider/ollama.go | 12 -- .../extensions/ai-proxy/provider/openai.go | 11 - .../extensions/ai-proxy/provider/provider.go | 124 +++++++++-- .../extensions/ai-proxy/provider/qwen.go | 198 ++++++------------ .../ai-proxy/provider/request_helper.go | 34 ++- .../extensions/ai-proxy/provider/spark.go | 11 - .../extensions/ai-proxy/provider/stepfun.go | 12 -- .../extensions/ai-proxy/provider/yi.go | 12 -- .../extensions/ai-proxy/provider/zhipuai.go | 12 -- 
.../wasm-go/extensions/ai-proxy/util/http.go | 3 - 30 files changed, 345 insertions(+), 471 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 532fee8259..102cf015c6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -170,7 +170,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo log.Errorf("unable to load :status header from response: %v", err) } ctx.DontReadResponseBody() - providerConfig.OnRequestFailed(apiTokenInUse, log) + providerConfig.OnRequestFailed(ctx, apiTokenInUse, log) return types.ActionContinue } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index f17052e3ae..be34b0c4d9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -4,12 +4,10 @@ import ( "encoding/json" "errors" "fmt" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) // ai360Provider is the provider for 360 OpenAI service. 
@@ -25,16 +23,6 @@ type ai360Provider struct { contextCache *contextCache } -func (m *ai360Provider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *ai360Provider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *ai360ProviderInitializer) ValidateConfig(config ProviderConfig) error { if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index a1f1af2301..959bd94061 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -3,7 +3,6 @@ package provider import ( "errors" "fmt" - "net/http" "net/url" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -51,16 +50,6 @@ type azureProvider struct { serviceUrl *url.URL } -func (m *azureProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *azureProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *azureProvider) GetProviderType() string { return providerTypeAzure } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index 25d063edc6..1be071e8c9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -3,8 +3,6 @@ package provider import ( "errors" "fmt" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" 
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -40,16 +38,6 @@ type baichuanProvider struct { contextCache *contextCache } -func (m *baichuanProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *baichuanProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *baichuanProvider) GetProviderType() string { return providerTypeBaichuan } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index 0edb50e6da..d0197e4aa5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "strings" "time" @@ -53,16 +52,6 @@ type baiduProvider struct { contextCache *contextCache } -func (b *baiduProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (b *baiduProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (b *baiduProvider) GetProviderType() string { return providerTypeBaidu } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 26a5fc914a..ceee5b1131 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -106,15 +106,11 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - - originalHeaders := util.GetOriginaHttplHeaders() 
- c.TransformRequestHeaders(originalHeaders, ctx, log) - util.ReplaceOriginalHttpHeaders(originalHeaders) - + c.config.handleRequestHeaders(c, ctx, apiName, log) return types.ActionContinue, nil } -func (c *claudeProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { +func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteHttpRequestPath(headers, claudeChatCompletionPath) util.OverwriteHttpRequestHost(headers, claudeDomain) @@ -133,41 +129,12 @@ func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - - // use original protocol - if c.config.protocol == protocolOriginal { - if c.config.context == nil { - return types.ActionContinue, nil - } - - request := &claudeTextGenRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - return types.ActionContinue, nil - } - - // use openai protocol - modifiedBody, err := c.TransformRequestBody(body, ctx, log) - if err != nil { - return types.ActionContinue, err - } - err = replaceHttpJsonRequestBody(modifiedBody, log) - if err != nil { - return types.ActionContinue, err - } - - return types.ActionContinue, nil - + return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) } -func (c *claudeProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { +func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return nil, err - } - - err := c.config.setRequestModel(ctx, request, log) + err := c.config.parseRequestAndMapModel(ctx, request, body, 
log) if err != nil { return nil, err } @@ -347,3 +314,18 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG func (c *claudeProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } + +func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) { + request := &claudeTextGenRequest{} + if err := json.Unmarshal(body, request); err != nil { + return nil, fmt.Errorf("unable to unmarshal request: %v", err) + } + + if request.System == "" { + request.System = content + } else { + request.System = content + "\n" + request.System + } + + return json.Marshal(request) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index d0fdbad1bf..b52f942b65 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -3,7 +3,6 @@ package provider import ( "errors" "fmt" - "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -40,16 +39,6 @@ type cloudflareProvider struct { contextCache *contextCache } -func (c *cloudflareProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (c *cloudflareProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (c *cloudflareProvider) GetProviderType() string { return providerTypeCloudflare } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index e530a6ef9e..c3d1bc9dc7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -4,8 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -36,16 +34,6 @@ type cohereProvider struct { config ProviderConfig } -func (m *cohereProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *cohereProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - type cohereTextGenRequest struct { Message string `json:"message,omitempty"` Model string `json:"model,omitempty"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/context.go b/plugins/wasm-go/extensions/ai-proxy/provider/context.go index 2026a9818a..86d6b98124 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/context.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/context.go @@ -6,7 +6,9 @@ import ( "net/http" "net/url" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/tidwall/gjson" ) @@ -57,6 +59,10 @@ type contextCache struct { content string } +type ContextInserter interface { + insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) +} + func (c *contextCache) GetContent(callback func(string, error), log wrapper.Log) error { if callback == nil { return errors.New("callback is nil") @@ -98,3 +104,53 @@ func createContextCache(providerConfig *ProviderConfig) *contextCache { timeout: providerConfig.timeout, } } + +func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte, log wrapper.Log) error { + // get 
context will overwrite the original request host and path + // save the original request host and path in case they are needed for apiToken health check + ctx.SetContext(ctxRequestHost, wrapper.GetRequestHost()) + ctx.SetContext(ctxRequestPath, wrapper.GetRequestPath()) + + if c.loaded { + log.Debugf("context file loaded from cache") + insertContext(provider, c.content, nil, body, log) + return nil + } + + log.Infof("loading context file from %s", c.fileUrl.String()) + return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != http.StatusOK { + insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil, log) + return + } + c.content = string(responseBody) + c.loaded = true + log.Debugf("content: %s", c.content) + insertContext(provider, c.content, nil, body, log) + }, c.timeout) +} + +func insertContext(provider Provider, content string, err error, body []byte, log wrapper.Log) { + defer func() { + _ = proxywasm.ResumeHttpRequest() + }() + + typ := provider.GetProviderType() + if err != nil { + log.Errorf("failed to load context file: %v", err) + _ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.load_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) + } + + if inserter, ok := provider.(ContextInserter); ok { + body, err = inserter.insertHttpContextMessage(body, content, false) + } else { + body, err = defaultInsertHttpContextMessage(body, content) + } + + if err != nil { + _ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to insert context message: %v", err)) + } + if err := replaceHttpJsonRequestBody(body, log); err != nil { + _ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) + } +} diff --git 
a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 786d635fec..6ff536dd70 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -28,16 +27,6 @@ type deeplProvider struct { contextCache *contextCache } -func (d *deeplProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (d *deeplProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - // spec reference: https://developers.deepl.com/docs/v/zh/api-reference/translate/openapi-spec-for-text-translation type deeplRequest struct { // "Model" parameter is used to distinguish which service to use diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index 331f7dd019..ecca678670 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -3,8 +3,6 @@ package provider import ( "errors" "fmt" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -40,16 +38,6 @@ type deepseekProvider struct { contextCache *contextCache } -func (m *deepseekProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *deepseekProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - 
func (m *deepseekProvider) GetProviderType() string { return providerTypeDeepSeek } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 1e0db5b441..1358eebc7a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -3,8 +3,6 @@ package provider import ( "errors" "fmt" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -37,16 +35,6 @@ type doubaoProvider struct { contextCache *contextCache } -func (m *doubaoProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *doubaoProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *doubaoProvider) GetProviderType() string { return providerTypeDoubao } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 569e5511f2..2336ff44a3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -63,6 +63,8 @@ const ( removeApiTokenOperation = "removeApiToken" addApiTokenRequestCountOperation = "addApiTokenRequestCount" resetApiTokenRequestCountOperation = "resetApiTokenRequestCount" + ctxRequestHost = "requestHost" + ctxRequestPath = "requestPath" ) var ( @@ -150,11 +152,12 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr ctx := createHttpContext() ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) - originalHeaders := util.SliceToHeader(headers) - activeProvider.TransformRequestHeaders(originalHeaders, ctx, log) - modifiedHeaders := 
util.HeaderToSlice(originalHeaders) - modifiedBody, _ := activeProvider.TransformRequestBody(body, ctx, log) + modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body, log) + if err != nil { + log.Errorf("Failed to transform request headers and body: %v", err) + } + // The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion err = healthCheckClient.Post(healthCheckEndpoint.Path, modifiedHeaders, modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode == 200 { c.handleAvailableApiToken(apiToken, log) @@ -171,6 +174,25 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr return nil } +func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte, log wrapper.Log) ([][2]string, []byte, error) { + originalHeaders := util.SliceToHeader(headers) + if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok { + handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, originalHeaders, log) + } + modifiedHeaders := util.HeaderToSlice(originalHeaders) + + var err error + if handler, ok := activeProvider.(TransformRequestBodyHandler); ok { + body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log) + } else { + body, err = c.defaultTransformRequestBody(ctx, body, log) + } + if err != nil { + return nil, nil, fmt.Errorf("failed to transform request body: %v", err) + } + return modifiedHeaders, body, nil +} + func createHttpContext() *wrapper.CommonHttpCtx[any] { setParseConfig := wrapper.ParseConfigBy[any](parseConfig) vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig) @@ -283,7 +305,7 @@ func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Lo // When number of request failures exceeds the threshold, // remove the apiToken from the available list 
and add it to the unavailable list -func (c *ProviderConfig) handleUnavailableApiToken(apiToken string, log wrapper.Log) { +func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string, log wrapper.Log) { failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount) if err != nil { log.Errorf("Failed to get failureApiTokenRequestCount: %v", err) @@ -307,7 +329,7 @@ func (c *ProviderConfig) handleUnavailableApiToken(apiToken string, log wrapper. addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log) resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log) // Set the request host and path to shared data in case they are needed in apiToken health check - c.setHealthCheckEndpoint(log) + c.setHealthCheckEndpoint(ctx, log) } else { log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount) addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log) @@ -511,9 +533,9 @@ func (c *ProviderConfig) resetSharedData() { _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0) } -func (c *ProviderConfig) OnRequestFailed(apiTokenInUse string, log wrapper.Log) { +func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) { if c.isFailoverEnabled() { - c.handleUnavailableApiToken(apiTokenInUse, log) + c.handleUnavailableApiToken(ctx, apiTokenInUse, log) } } @@ -533,22 +555,33 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) } -func (c *ProviderConfig) setHealthCheckEndpoint(log wrapper.Log) { +func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext, log wrapper.Log) { cluster, err := proxywasm.GetProperty([]string{"cluster_name"}) if err != nil { log.Errorf("Failed to get cluster_name: %v", err) } 
- hostPath := HealthCheckEndpoint{ - Host: wrapper.GetRequestHost(), - Path: wrapper.GetRequestPath(), + + host := wrapper.GetRequestHost() + if host == "" { + host = ctx.GetContext(ctxRequestHost).(string) + } + path := wrapper.GetRequestPath() + if path == "" { + path = ctx.GetContext(ctxRequestPath).(string) + } + + healthCheckEndpoint := HealthCheckEndpoint{ + Host: host, + Path: path, Cluster: string(cluster), } - hostPathByte, err := json.Marshal(hostPath) + + healthCheckEndpointByte, err := json.Marshal(healthCheckEndpoint) if err != nil { log.Errorf("Failed to marshal request host and path: %v", err) } - err = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, hostPathByte, 0) + err = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, healthCheckEndpointByte, 0) if err != nil { log.Errorf("Failed to set request host and path: %v", err) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index b0563dc926..b9c157d864 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "strings" "time" @@ -44,16 +43,6 @@ type geminiProvider struct { contextCache *contextCache } -func (g *geminiProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (g *geminiProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (g *geminiProvider) GetProviderType() string { return providerTypeGemini } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index ce9ea75bb6..7cf28f69cc 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -4,12 +4,10 @@ import ( "encoding/json" "errors" "fmt" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) // githubProvider is the provider for GitHub OpenAI service. @@ -27,16 +25,6 @@ type githubProvider struct { contextCache *contextCache } -func (m *githubProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *githubProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *githubProviderInitializer) ValidateConfig(config ProviderConfig) error { if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index c2a67abf62..c3eb74faf9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -1,7 +1,6 @@ package provider import ( - "encoding/json" "errors" "net/http" @@ -45,9 +44,7 @@ func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - originalHeaders := util.GetOriginaHttplHeaders() - g.TransformRequestHeaders(originalHeaders, ctx, log) - util.ReplaceOriginalHttpHeaders(originalHeaders) + g.config.handleRequestHeaders(g, ctx, apiName, log) return types.ActionContinue, nil } @@ -55,34 +52,12 @@ func (g 
*groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - modifiedBody, err := g.TransformRequestBody(body, ctx, log) - if err != nil { - return types.ActionContinue, err - } - err = replaceHttpJsonRequestBody(modifiedBody, log) - if err != nil { - return types.ActionContinue, err - } - return types.ActionContinue, nil + return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } -func (g *groqProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { +func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteHttpRequestPath(headers, groqChatCompletionPath) util.OverwriteHttpRequestHost(headers, groqDomain) util.OverwriteHttpRequestAuthorization(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } - -func (g *groqProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return nil, err - } - - err := g.config.setRequestModel(ctx, request, log) - if err != nil { - return nil, err - } - - return json.Marshal(request) -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index 757447dcf0..7640a380b3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -8,7 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "strings" "time" @@ -110,16 +109,6 @@ type hunyuanProvider struct { contextCache *contextCache } -func (m *hunyuanProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m 
*hunyuanProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *hunyuanProvider) GetProviderType() string { return providerTypeHunyuan } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 83c7952867..62173deeb3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -71,16 +70,6 @@ type minimaxProvider struct { contextCache *contextCache } -func (m *minimaxProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *minimaxProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *minimaxProvider) GetProviderType() string { return providerTypeMinimax } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 343bd483e4..d44d3ce55a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -3,8 +3,6 @@ package provider import ( "errors" "fmt" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -36,16 +34,6 @@ type mistralProvider struct { contextCache *contextCache } -func (m *mistralProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *mistralProvider) 
TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *mistralProvider) GetProviderType() string { return providerTypeMistral } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 2eb834676f..e76bcf529d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -50,16 +50,6 @@ type moonshotProvider struct { contextCache *contextCache } -func (m *moonshotProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *moonshotProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *moonshotProvider) GetProviderType() string { return providerTypeMoonshot } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 736af60e6f..310dbeb3a1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -3,8 +3,6 @@ package provider import ( "errors" "fmt" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -46,16 +44,6 @@ type ollamaProvider struct { contextCache *contextCache } -func (m *ollamaProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *ollamaProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *ollamaProvider) 
GetProviderType() string { return providerTypeOllama } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 9b57c28e94..42d991737f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -2,7 +2,6 @@ package provider import ( "fmt" - "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -53,16 +52,6 @@ type openaiProvider struct { contextCache *contextCache } -func (m *openaiProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *openaiProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *openaiProvider) GetProviderType() string { return providerTypeOpenAI } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 3b07b9a127..827a5e2472 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -1,15 +1,16 @@ package provider import ( + "encoding/json" "errors" "math/rand" "net/http" "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" - - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) type ApiName string @@ -107,18 +108,24 @@ var ( type Provider interface { GetProviderType() string - TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) - TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) } type RequestHeadersHandler interface { OnRequestHeaders(ctx 
wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) } +type TransformRequestHeadersHandler interface { + TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) +} + type RequestBodyHandler interface { OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) } +type TransformRequestBodyHandler interface { + TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) +} + type ResponseHeadersHandler interface { OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) } @@ -373,20 +380,51 @@ func CreateProvider(pc ProviderConfig) (Provider, error) { return initializer.CreateProvider(pc) } -func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request *chatCompletionRequest, log wrapper.Log) error { - model := request.Model - if model == "" { - return errors.New("missing model in chat completion request") +func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error { + switch req := request.(type) { + case *chatCompletionRequest: + if err := decodeChatCompletionRequest(body, req); err != nil { + return err + } + return c.setRequestModel(ctx, req, log) + case *embeddingsRequest: + if err := decodeEmbeddingsRequest(body, req); err != nil { + return err + } + return c.setRequestModel(ctx, req, log) + default: + return errors.New("unsupported request type") + } +} + +func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error { + var model *string + + switch req := request.(type) { + case *chatCompletionRequest: + model = &req.Model + case *embeddingsRequest: + model = &req.Model + default: + return errors.New("unsupported request type") } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, 
c.modelMapping, log) + return c.mapModel(ctx, model, log) +} + +func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error { + if *model == "" { + return errors.New("missing model in request") + } + ctx.SetContext(ctxKeyOriginalRequestModel, *model) + + mappedModel := getMappedModel(*model, c.modelMapping, log) if mappedModel == "" { return errors.New("model becomes empty after applying the configured mapping") } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) + *model = mappedModel + ctx.SetContext(ctxKeyFinalRequestModel, *model) return nil } @@ -426,3 +464,65 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper. return "" } + +func (c *ProviderConfig) handleRequestBody( + provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log, +) (types.Action, error) { + // use original protocol + if c.protocol == protocolOriginal { + if apiName == ApiNameChatCompletion { + if c.context == nil { + return types.ActionContinue, nil + } + err := contextCache.GetContextFromFile(ctx, provider, body, log) + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err + } + return types.ActionContinue, nil + } + + // use openai protocol + var err error + if handler, ok := provider.(TransformRequestBodyHandler); ok { + body, err = handler.TransformRequestBody(ctx, apiName, body, log) + } else { + body, err = c.defaultTransformRequestBody(ctx, body, log) + } + + if err != nil { + return types.ActionContinue, err + } + + if apiName == ApiNameChatCompletion { + if c.context == nil { + return types.ActionContinue, replaceHttpJsonRequestBody(body, log) + } + err = contextCache.GetContextFromFile(ctx, provider, body, log) + + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err + } + return types.ActionContinue, replaceHttpJsonRequestBody(body, log) +} + +func (c 
*ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) { + if handler, ok := provider.(TransformRequestHeadersHandler); ok { + originalHeaders := util.GetOriginaHttplHeaders() + handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log) + util.ReplaceOriginalHttpHeaders(originalHeaders) + } +} + +func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { + request := &chatCompletionRequest{} + err := c.parseRequestAndMapModel(ctx, request, body, log) + if err != nil { + return nil, err + } + + return json.Marshal(request) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 09f9f58374..2709e89029 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -64,14 +64,28 @@ type qwenProvider struct { contextCache *contextCache } -func (m *qwenProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") +func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteHttpRequestHost(headers, qwenDomain) + util.OverwriteHttpRequestAuthorization(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + + if m.config.qwenEnableCompatible { + util.OverwriteHttpRequestPath(headers, qwenCompatiblePath) + } else if apiName == ApiNameChatCompletion { + util.OverwriteHttpRequestPath(headers, qwenChatCompletionPath) + } else if apiName == ApiNameEmbeddings { + util.OverwriteHttpRequestPath(headers, qwenTextEmbeddingPath) + } + + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } -func (m *qwenProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - 
panic("implement me") +func (m *qwenProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName == ApiNameChatCompletion { + return m.onChatCompletionRequestBody(ctx, body, log) + } else { + return m.onEmbeddingsRequestBody(ctx, body, log) + } } func (m *qwenProvider) GetProviderType() string { @@ -79,25 +93,17 @@ func (m *qwenProvider) GetProviderType() string { } func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - _ = util.OverwriteRequestHost(qwenDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) + if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + return types.ActionContinue, errUnsupportedApiName + } + + m.config.handleRequestHeaders(m, ctx, apiName, log) if m.config.protocol == protocolOriginal { ctx.DontReadRequestBody() return types.ActionContinue, nil - } else if m.config.qwenEnableCompatible { - _ = util.OverwriteRequestPath(qwenCompatiblePath) - } else if apiName == ApiNameChatCompletion { - _ = util.OverwriteRequestPath(qwenChatCompletionPath) - } else if apiName == ApiNameEmbeddings { - _ = util.OverwriteRequestPath(qwenTextEmbeddingPath) - } else { - return types.ActionContinue, errUnsupportedApiName } - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } @@ -132,62 +138,20 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b } return types.ActionContinue, nil } - if apiName == ApiNameChatCompletion { - return m.onChatCompletionRequestBody(ctx, body, log) - } - if apiName == ApiNameEmbeddings { - return m.onEmbeddingsRequestBody(ctx, body, log) - } - return types.ActionContinue, errUnsupportedApiName -} -func (m *qwenProvider) 
onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - if m.config.protocol == protocolOriginal { - if m.config.context == nil { - return types.ActionContinue, nil - } - - request := &qwenTextGenRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - m.insertContextMessage(request, content, false) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + return types.ActionContinue, errUnsupportedApiName } + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} +func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + err := m.config.parseRequestAndMapModel(ctx, request, body, log) + if err != nil { + return nil, err } - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, 
errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) // Use the qwen multimodal model generation API if strings.HasPrefix(request.Model, qwenVlModelPrefixName) { _ = util.OverwriteRequestPath(qwenMultimodalGenerationPath) @@ -202,62 +166,21 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body _ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE") } - if m.config.context == nil { - qwenRequest := m.buildQwenTextGenerationRequest(request, streaming) - if streaming { - ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput) - } - return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log) - } - - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - qwenRequest := m.buildQwenTextGenerationRequest(request, streaming) - if streaming { - ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput) - } - if err := replaceJsonRequestBody(qwenRequest, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + return m.buildQwenTextGenerationRequest(ctx, request, streaming) } -func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { +func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { request := 
&embeddingsRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - - log.Debugf("=== embeddings request: %v", request) - - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in the request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + err := m.config.parseRequestAndMapModel(ctx, request, body, log) + if err != nil { + return nil, err } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - if qwenRequest, err := m.buildQwenTextEmbeddingRequest(request); err == nil { - return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log) - } else { - return types.ActionContinue, err + qwenRequest, err := m.buildQwenTextEmbeddingRequest(request) + if err != nil { + return nil, err } + return json.Marshal(qwenRequest) } func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -386,7 +309,7 @@ func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body [] return types.ActionContinue, replaceJsonResponseBody(response, log) } -func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest, streaming bool) *qwenTextGenRequest { +func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) { messages := make([]qwenMessage, 0, len(origRequest.Messages)) for i := range origRequest.Messages { messages = append(messages, chatMessage2QwenMessage(origRequest.Messages[i])) @@ -408,6 +331,11 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio Tools: origRequest.Tools, }, 
} + + if streaming { + ctx.SetContext(ctxKeyIncrementalStreaming, request.Parameters.IncrementalOutput) + } + if len(m.config.qwenFileIds) != 0 && origRequest.Model == qwenLongModelName { builder := strings.Builder{} for _, fileId := range m.config.qwenFileIds { @@ -417,13 +345,15 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio builder.WriteString("fileid://") builder.WriteString(fileId) } - contextMessageId := m.insertContextMessage(request, builder.String(), true) - if contextMessageId == 0 { - // The context message cannot come first. We need to add another dummy system message before it. - request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...) + + body, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("unable to marshal request: %v", err) } + + return m.insertHttpContextMessage(body, builder.String(), true) } - return request + return json.Marshal(request) } func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse { @@ -580,7 +510,12 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild return nil } -func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string, onlyOneSystemBeforeFile bool) int { +func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) { + request := &qwenTextGenRequest{} + if err := json.Unmarshal(body, request); err != nil { + return nil, fmt.Errorf("unable to unmarshal request: %v", err) + } + fileMessage := qwenMessage{ Role: roleSystem, Content: content, @@ -597,10 +532,8 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content } if firstNonSystemMessageIndex == 0 { request.Input.Messages = append([]qwenMessage{fileMessage}, request.Input.Messages...) 
- return 0 } else if !onlyOneSystemBeforeFile { request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...) - return firstNonSystemMessageIndex } else { builder := strings.Builder{} for _, message := range request.Input.Messages[:firstNonSystemMessageIndex] { @@ -610,8 +543,15 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content builder.WriteString(message.StringContent()) } request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: builder.String()}, fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...) - return 1 + firstNonSystemMessageIndex = 1 } + + if firstNonSystemMessageIndex == 0 { + // The context message cannot come first. We need to add another dummy system message before it. + request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...) 
+ } + + return json.Marshal(request) } func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go index aa8b6104fb..defe55f784 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go @@ -3,7 +3,6 @@ package provider import ( "encoding/json" "fmt" - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" ) @@ -18,6 +17,13 @@ func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) er return nil } +func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error { + if err := json.Unmarshal(body, request); err != nil { + return fmt.Errorf("unable to unmarshal request: %v", err) + } + return nil +} + func replaceJsonRequestBody(request interface{}, log wrapper.Log) error { body, err := json.Marshal(request) if err != nil { @@ -59,6 +65,32 @@ func insertContextMessage(request *chatCompletionRequest, content string) { } } +func defaultInsertHttpContextMessage(body []byte, content string) ([]byte, error) { + request := &chatCompletionRequest{} + if err := json.Unmarshal(body, request); err != nil { + return nil, fmt.Errorf("unable to unmarshal request: %v", err) + } + + fileMessage := chatMessage{ + Role: roleSystem, + Content: content, + } + var firstNonSystemMessageIndex int + for i, message := range request.Messages { + if message.Role != roleSystem { + firstNonSystemMessageIndex = i + break + } + } + if firstNonSystemMessageIndex == 0 { + request.Messages = append([]chatMessage{fileMessage}, request.Messages...) + } else { + request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...) 
+ } + + return json.Marshal(request) +} + func replaceJsonResponseBody(response interface{}, log wrapper.Log) error { body, err := json.Marshal(response) if err != nil { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 6dd42e9131..67c1e001bd 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "strings" "time" @@ -28,16 +27,6 @@ type sparkProvider struct { contextCache *contextCache } -func (p *sparkProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (p *sparkProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - type sparkRequest struct { Model string `json:"model"` Messages []chatMessage `json:"messages"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index 987a8106a0..2bc80f6f81 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -3,8 +3,6 @@ package provider import ( "errors" "fmt" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -38,16 +36,6 @@ type stepfunProvider struct { contextCache *contextCache } -func (m *stepfunProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *stepfunProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - 
func (m *stepfunProvider) GetProviderType() string { return providerTypeStepfun } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 76719ed51a..511d251553 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -3,8 +3,6 @@ package provider import ( "errors" "fmt" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -38,16 +36,6 @@ type yiProvider struct { contextCache *contextCache } -func (m *yiProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *yiProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO implement me - panic("implement me") -} - func (m *yiProvider) GetProviderType() string { return providerTypeYi } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 1fe87a63ca..98c824e4f3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -3,8 +3,6 @@ package provider import ( "errors" "fmt" - "net/http" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -37,16 +35,6 @@ type zhipuAiProvider struct { contextCache *contextCache } -func (m *zhipuAiProvider) TransformRequestHeaders(headers http.Header, ctx wrapper.HttpContext, log wrapper.Log) { - //TODO implement me - panic("implement me") -} - -func (m *zhipuAiProvider) TransformRequestBody(body []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]byte, error) { - //TODO 
implement me - panic("implement me") -} - func (m *zhipuAiProvider) GetProviderType() string { return providerTypeZhipuAi } diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index fa0d119baf..2c8156c413 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -24,7 +24,6 @@ func CreateHeaders(kvs ...string) [][2]string { return headers } -// TODO: remove func OverwriteRequestHost(host string) error { if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil { _ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-HOST", originHost) @@ -32,7 +31,6 @@ func OverwriteRequestHost(host string) error { return proxywasm.ReplaceHttpRequestHeader(":authority", host) } -// TODO: remove func OverwriteRequestPath(path string) error { if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil { _ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-PATH", originPath) @@ -40,7 +38,6 @@ func OverwriteRequestPath(path string) error { return proxywasm.ReplaceHttpRequestHeader(":path", path) } -// TODO: remove func OverwriteRequestAuthorization(credential string) error { if exist, _ := proxywasm.GetHttpRequestHeader("X-HI-ORIGINAL-AUTH"); exist == "" { if originAuth, err := proxywasm.GetHttpRequestHeader("Authorization"); err == nil { From f1f375ef3b22489900f5ef1c29ab3d856ed179aa Mon Sep 17 00:00:00 2001 From: Se7en Date: Sun, 3 Nov 2024 15:33:21 +0800 Subject: [PATCH 25/31] move defaultInsertHttpContextMessage to context.go --- .../extensions/ai-proxy/provider/context.go | 27 +++++++++++++++++++ .../ai-proxy/provider/request_helper.go | 26 ------------------ 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/context.go b/plugins/wasm-go/extensions/ai-proxy/provider/context.go index 86d6b98124..d9fe2e26c4 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/context.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/context.go @@ -1,6 +1,7 @@ package provider import ( + "encoding/json" "errors" "fmt" "net/http" @@ -154,3 +155,29 @@ func insertContext(provider Provider, content string, err error, body []byte, lo _ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) } } + +func defaultInsertHttpContextMessage(body []byte, content string) ([]byte, error) { + request := &chatCompletionRequest{} + if err := json.Unmarshal(body, request); err != nil { + return nil, fmt.Errorf("unable to unmarshal request: %v", err) + } + + fileMessage := chatMessage{ + Role: roleSystem, + Content: content, + } + var firstNonSystemMessageIndex int + for i, message := range request.Messages { + if message.Role != roleSystem { + firstNonSystemMessageIndex = i + break + } + } + if firstNonSystemMessageIndex == 0 { + request.Messages = append([]chatMessage{fileMessage}, request.Messages...) + } else { + request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...) 
+ } + + return json.Marshal(request) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go index defe55f784..dd9864702e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go @@ -65,32 +65,6 @@ func insertContextMessage(request *chatCompletionRequest, content string) { } } -func defaultInsertHttpContextMessage(body []byte, content string) ([]byte, error) { - request := &chatCompletionRequest{} - if err := json.Unmarshal(body, request); err != nil { - return nil, fmt.Errorf("unable to unmarshal request: %v", err) - } - - fileMessage := chatMessage{ - Role: roleSystem, - Content: content, - } - var firstNonSystemMessageIndex int - for i, message := range request.Messages { - if message.Role != roleSystem { - firstNonSystemMessageIndex = i - break - } - } - if firstNonSystemMessageIndex == 0 { - request.Messages = append([]chatMessage{fileMessage}, request.Messages...) - } else { - request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...) 
- } - - return json.Marshal(request) -} - func replaceJsonResponseBody(response interface{}, log wrapper.Log) error { body, err := json.Marshal(response) if err != nil { From 02961104c3be4c8b98147c34d5ba62721f3d555a Mon Sep 17 00:00:00 2001 From: Se7en Date: Tue, 5 Nov 2024 21:21:19 +0800 Subject: [PATCH 26/31] fix --- plugins/wasm-go/extensions/ai-proxy/provider/provider.go | 2 +- plugins/wasm-go/extensions/ai-proxy/util/http.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 827a5e2472..0aec59d487 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -511,7 +511,7 @@ func (c *ProviderConfig) handleRequestBody( func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) { if handler, ok := provider.(TransformRequestHeadersHandler); ok { - originalHeaders := util.GetOriginaHttplHeaders() + originalHeaders := util.GetOriginalHttpHeaders() handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log) util.ReplaceOriginalHttpHeaders(originalHeaders) } diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index 2c8156c413..edaf60506b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -1,8 +1,9 @@ package util import ( - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "net/http" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" ) const ( @@ -68,7 +69,6 @@ func OverwriteHttpRequestAuthorization(headers http.Header, credential string) { } } headers.Set("Authorization", credential) - } func HeaderToSlice(header http.Header) [][2]string { @@ -91,7 +91,7 @@ func SliceToHeader(slice [][2]string) http.Header { return header } -func 
GetOriginaHttplHeaders() http.Header { +func GetOriginalHttpHeaders() http.Header { originalHeaders, _ := proxywasm.GetHttpRequestHeaders() return SliceToHeader(originalHeaders) } From f1648544daa79b4bbc11246d603ee02554440f84 Mon Sep 17 00:00:00 2001 From: Se7en Date: Tue, 5 Nov 2024 21:34:20 +0800 Subject: [PATCH 27/31] remove get context in original protocol --- .../wasm-go/extensions/ai-proxy/provider/provider.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 0aec59d487..a5a75e8027 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -470,16 +470,6 @@ func (c *ProviderConfig) handleRequestBody( ) (types.Action, error) { // use original protocol if c.protocol == protocolOriginal { - if apiName == ApiNameChatCompletion { - if c.context == nil { - return types.ActionContinue, nil - } - err := contextCache.GetContextFromFile(ctx, provider, body, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err - } return types.ActionContinue, nil } From 0938f983dad01763f63a8ff05e075492ac76ac14 Mon Sep 17 00:00:00 2001 From: Se7en Date: Wed, 6 Nov 2024 22:30:04 +0800 Subject: [PATCH 28/31] add reset apiToken log --- plugins/wasm-go/extensions/ai-proxy/provider/failover.go | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 2336ff44a3..c43f0fc622 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -462,6 +462,7 @@ func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log.Errorf("failed to get failureApiTokenRequestCount: %v", err) } if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok { + 
log.Infof("reset apiToken %s request failure count", apiTokenInUse) resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log) } } From 8e425c30e26c10cb542f0fe4e69bdcd0943ca99d Mon Sep 17 00:00:00 2001 From: Se7en Date: Thu, 14 Nov 2024 11:14:04 +0800 Subject: [PATCH 29/31] add GetApiName to determine apiName for original protocol --- plugins/wasm-go/extensions/ai-proxy/main.go | 21 +- .../extensions/ai-proxy/provider/ai360.go | 62 ++---- .../extensions/ai-proxy/provider/azure.go | 61 +++--- .../extensions/ai-proxy/provider/baichuan.go | 46 ++--- .../extensions/ai-proxy/provider/baidu.go | 108 +++------- .../extensions/ai-proxy/provider/claude.go | 20 +- .../ai-proxy/provider/cloudflare.go | 69 ++----- .../extensions/ai-proxy/provider/cohere.go | 68 +++---- .../extensions/ai-proxy/provider/deepl.go | 89 +++++---- .../extensions/ai-proxy/provider/deepseek.go | 46 ++--- .../extensions/ai-proxy/provider/doubao.go | 66 ++---- .../extensions/ai-proxy/provider/failover.go | 9 +- .../extensions/ai-proxy/provider/gemini.go | 189 +++++------------- .../extensions/ai-proxy/provider/github.go | 67 ++----- .../extensions/ai-proxy/provider/groq.go | 14 +- .../extensions/ai-proxy/provider/hunyuan.go | 70 ++++--- .../extensions/ai-proxy/provider/minimax.go | 95 +++------ .../extensions/ai-proxy/provider/mistral.go | 47 ++--- .../extensions/ai-proxy/provider/moonshot.go | 34 ++-- .../extensions/ai-proxy/provider/ollama.go | 63 ++---- .../extensions/ai-proxy/provider/openai.go | 57 +++--- .../extensions/ai-proxy/provider/provider.go | 43 +++- .../extensions/ai-proxy/provider/qwen.go | 37 ++-- .../extensions/ai-proxy/provider/spark.go | 54 ++--- .../extensions/ai-proxy/provider/stepfun.go | 47 ++--- .../extensions/ai-proxy/provider/yi.go | 47 ++--- .../extensions/ai-proxy/provider/zhipuai.go | 47 ++--- .../wasm-go/extensions/ai-proxy/util/http.go | 13 +- 28 files changed, 609 insertions(+), 980 deletions(-) diff --git 
a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 102cf015c6..3a29575c21 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -80,9 +80,16 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf rawPath := ctx.Path() path, _ := url.Parse(rawPath) - apiName := getOpenAiApiName(path.Path) + + var apiName provider.ApiName providerConfig := pluginConfig.GetProviderConfig() - if apiName == "" && !providerConfig.IsOriginal() { + if providerConfig.IsOriginal() { + apiName = activeProvider.GetApiName(path.Path) + } else { + apiName = provider.GetOpenAiApiName(path.Path) + } + + if apiName == "" { log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path) _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) return types.ActionContinue @@ -247,16 +254,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi return types.ActionContinue } -func getOpenAiApiName(path string) provider.ApiName { - if strings.HasSuffix(path, "/v1/chat/completions") { - return provider.ApiNameChatCompletion - } - if strings.HasSuffix(path, "/v1/embeddings") { - return provider.ApiNameEmbeddings - } - return "" -} - func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) { contentType, err := proxywasm.GetHttpResponseHeader("Content-Type") if err != nil || !strings.HasPrefix(contentType, "text/event-stream") { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index be34b0c4d9..439fd8ff10 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -1,18 +1,19 @@ package provider import ( - "encoding/json" "errors" - "fmt" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" 
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) // ai360Provider is the provider for 360 OpenAI service. const ( - ai360Domain = "api.360.cn" + ai360Domain = "api.360.cn" + ai360ChatCompletionPath = "/v1/chat/completions" ) type ai360ProviderInitializer struct { @@ -45,10 +46,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(ai360Domain) - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - _ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetApiTokenInUse(ctx)) + m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } @@ -57,47 +55,19 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } - if apiName == ApiNameChatCompletion { - return m.onChatCompletionRequestBody(ctx, body, log) - } - if apiName == ApiNameEmbeddings { - return m.onEmbeddingsRequestBody(ctx, body, log) - } - return types.ActionContinue, errUnsupportedApiName + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } -func (m *ai360Provider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - if request.Model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - // 映射模型 - 
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - request.Model = mappedModel - return types.ActionContinue, replaceJsonRequestBody(request, log) +func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, ai360Domain) + util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } -func (m *ai360Provider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - request := &embeddingsRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("missing model in embeddings request") - } - // 映射模型 - mappedModel := getMappedModel(request.Model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") +func (m *ai360Provider) GetApiName(path string) ApiName { + if strings.Contains(path, ai360ChatCompletionPath) { + return ApiNameChatCompletion } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - request.Model = mappedModel - return types.ActionContinue, replaceJsonRequestBody(request, log) + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 959bd94061..9919aeb073 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -3,16 +3,20 @@ package provider import ( "errors" "fmt" + 
"net/http" "net/url" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) -// azureProvider is the provider for Azure OpenAI service. +const ( + azureChatCompletionPath = "/chat/completions" +) +// azureProvider is the provider for Azure OpenAI service. type azureProviderInitializer struct { } @@ -55,47 +59,30 @@ func (m *azureProvider) GetProviderType() string { } func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - _ = util.OverwriteRequestPath(m.serviceUrl.RequestURI()) - _ = util.OverwriteRequestHost(m.serviceUrl.Host) - _ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.GetApiTokenInUse(ctx)) - if apiName == ApiNameChatCompletion { - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - } else { - ctx.DontReadRequestBody() + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName } + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { - // We don't need to process the request body for other APIs. 
- return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - return types.ActionContinue, nil - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.azure.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.azure.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI()) + util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host) + util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *azureProvider) GetApiName(path string) ApiName { + if strings.Contains(path, azureChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index 
1be071e8c9..3ed477602c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -2,11 +2,11 @@ package provider import ( "errors" - "fmt" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) // baichuanProvider is the provider for baichuan Ai service. @@ -46,10 +46,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(baichuanChatCompletionPath) - _ = util.OverwriteRequestHost(baichuanDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -57,28 +54,19 @@ func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.baichuan.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, 
"ai-proxy.baichuan.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath) + util.OverwriteRequestHostHeader(headers, baichuanDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *baichuanProvider) GetApiName(path string) ApiName { + if strings.Contains(path, baichuanChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index d0197e4aa5..42a1bc723d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -16,7 +17,8 @@ import ( // baiduProvider is the provider for baidu ernie bot service. 
const ( - baiduDomain = "aip.baidubce.com" + baiduDomain = "aip.baidubce.com" + baiduChatCompletionPath = "/chat" ) var baiduModelToPathSuffixMap = map[string]string{ @@ -60,98 +62,35 @@ func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(baiduDomain) - - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - + b.config.handleRequestHeaders(b, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } +func (b *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, baiduDomain) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") +} + func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - // 使用文心一言接口协议 - if b.config.protocol == protocolOriginal { - request := &baiduTextGenRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("request model is empty") - } - // 根据模型重写requestPath - path := b.getRequestPath(ctx, request.Model) - _ = util.OverwriteRequestPath(path) - - if b.config.context == nil { - return types.ActionContinue, nil - } + return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body, log) +} - err := b.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = 
util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - b.setSystemContent(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err - } +func (b *baiduProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + err := b.config.parseRequestAndMapModel(ctx, request, body, log) + if err != nil { + return nil, err } + path := b.getRequestPath(ctx, request.Model) + util.OverwriteRequestPathHeader(headers, path) - // 映射模型重写requestPath - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, b.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - path := b.getRequestPath(ctx, mappedModel) - _ = util.OverwriteRequestPath(path) - - if b.config.context == nil { - baiduRequest := b.baiduTextGenRequest(request) - return types.ActionContinue, replaceJsonRequestBody(baiduRequest, log) - } - - err := b.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", 
util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - baiduRequest := b.baiduTextGenRequest(request) - if err := replaceJsonRequestBody(baiduRequest, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + baiduRequest := b.baiduTextGenRequest(request) + return json.Marshal(baiduRequest) } func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -339,3 +278,10 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } + +func (b *baiduProvider) GetApiName(path string) ApiName { + if strings.Contains(path, baiduChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index ceee5b1131..8b98d62d64 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -111,8 +111,8 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteHttpRequestPath(headers, claudeChatCompletionPath) - util.OverwriteHttpRequestHost(headers, claudeDomain) + util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath) + util.OverwriteRequestHostHeader(headers, claudeDomain) headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx)) @@ -134,16 +134,9 @@ func 
(c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} - err := c.config.parseRequestAndMapModel(ctx, request, body, log) - if err != nil { + if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { return nil, err } - - streaming := request.Stream - if streaming { - _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") - } - claudeRequest := c.buildClaudeTextGenRequest(request) return json.Marshal(claudeRequest) } @@ -329,3 +322,10 @@ func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, o return json.Marshal(request) } + +func (c *claudeProvider) GetApiName(path string) ApiName { + if strings.Contains(path, claudeChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index b52f942b65..a4c02381fa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -2,19 +2,19 @@ package provider import ( "errors" - "fmt" + "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) const ( cloudflareDomain = "api.cloudflare.com" // https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility/ - cloudflareChatCompletionPath = "/client/v4/accounts/{account_id}/ai/v1/chat/completions" + cloudflareChatCompletionPath = "/v1/chat/completions" + cloudflareChatCompletionFullPath = "/client/v4/accounts/{account_id}/ai/v1/chat/completions" ) type cloudflareProviderInitializer 
struct { @@ -47,13 +47,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1)) - _ = util.OverwriteRequestHost(cloudflareDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + c.config.GetApiTokenInUse(ctx)) - - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - + c.config.handleRequestHeaders(c, ctx, apiName, log) return types.ActionContinue, nil } @@ -61,49 +55,20 @@ func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiN if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } + return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) +} - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, c.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - - streaming := request.Stream - if streaming { - _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") - } +func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionFullPath, "{account_id}", c.config.cloudflareAccountId, 1)) + 
util.OverwriteRequestHostHeader(headers, cloudflareDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx)) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") +} - if c.contextCache == nil { - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.cloudflare.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - return types.ActionContinue, nil - } - err := c.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.cloudflare.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.cloudflare.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil +func (c *cloudflareProvider) GetApiName(path string) ApiName { + if strings.Contains(path, cloudflareChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index c3d1bc9dc7..72dbaf280b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -3,16 +3,16 @@ package provider import ( "encoding/json" "errors" - "fmt" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" 
+ "strings" ) const ( - cohereDomain = "api.cohere.com" - chatCompletionPath = "/v1/chat" + cohereDomain = "api.cohere.com" + cohereChatCompletionPath = "/v1/chat" ) type cohereProviderInitializer struct{} @@ -26,12 +26,14 @@ func (m *cohereProviderInitializer) ValidateConfig(config ProviderConfig) error func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { return &cohereProvider{ - config: config, + config: config, + contextCache: createContextCache(&config), }, nil } type cohereProvider struct { - config ProviderConfig + config ProviderConfig + contextCache *contextCache } type cohereTextGenRequest struct { @@ -56,10 +58,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(cohereDomain) - _ = util.OverwriteRequestPath(chatCompletionPath) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -67,30 +66,7 @@ func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.config.protocol == protocolOriginal { - request := &cohereTextGenRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - return m.handleRequestBody(log, request) - } - origin := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, origin); err != nil { - return types.ActionContinue, err - } - request := m.buildCohereRequest(origin) - return m.handleRequestBody(log, request) -} - -func (m *cohereProvider) handleRequestBody(log wrapper.Log, request interface{}) (types.Action, error) { - defer func() { - 
_ = proxywasm.ResumeHttpRequest() - }() - err := replaceJsonRequestBody(request, log) - if err != nil { - _ = util.SendResponse(500, "ai-proxy.cohere.proxy_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - return types.ActionContinue, err + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohereTextGenRequest { @@ -111,3 +87,27 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe PresencePenalty: origin.PresencePenalty, } } + +func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath) + util.OverwriteRequestHostHeader(headers, cohereDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + request := &chatCompletionRequest{} + if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { + return nil, err + } + + cohereRequest := m.buildCohereRequest(request) + return json.Marshal(cohereRequest) +} + +func (m *cohereProvider) GetApiName(path string) ApiName { + if strings.Contains(path, cohereChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 6ff536dd70..bafe6b3dde 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -4,6 +4,8 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "strings" "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -78,49 +80,38 @@ 
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(deeplChatCompletionPath) - _ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + d.config.handleRequestHeaders(d, ctx, apiName, log) return types.HeaderStopIteration, nil } +func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath) + util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") + headers.Del("Accept-Encoding") +} + func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if d.config.protocol == protocolOriginal { - request := &deeplRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if err := d.overwriteRequestHost(request.Model); err != nil { - return types.ActionContinue, err - } - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - return types.ActionContinue, replaceJsonRequestBody(request, log) - } else { - originRequest := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, originRequest); err != nil { - return types.ActionContinue, err - } - if err := d.overwriteRequestHost(originRequest.Model); err != nil { - return types.ActionContinue, err - } - ctx.SetContext(ctxKeyFinalRequestModel, originRequest.Model) - deeplRequest := &deeplRequest{ - Text: make([]string, 0), - TargetLang: 
d.config.targetLang, - } - for _, msg := range originRequest.Messages { - if msg.Role == roleSystem { - deeplRequest.Context = msg.StringContent() - } else { - deeplRequest.Text = append(deeplRequest.Text, msg.StringContent()) - } - } - return types.ActionContinue, replaceJsonRequestBody(deeplRequest, log) + return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) +} + +func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return nil, err } + ctx.SetContext(ctxKeyFinalRequestModel, request.Model) + + err := d.overwriteRequestHost(headers, request.Model) + if err != nil { + return nil, err + } + + deeplReq := d.deeplTextGenRequest(request) + return json.Marshal(deeplReq) } func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -164,13 +155,35 @@ func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplRespo } } -func (d *deeplProvider) overwriteRequestHost(model string) error { +func (d *deeplProvider) overwriteRequestHost(headers http.Header, model string) error { if model == "Pro" { - _ = util.OverwriteRequestHost(deeplHostPro) + util.OverwriteRequestHostHeader(headers, deeplHostPro) } else if model == "Free" { - _ = util.OverwriteRequestHost(deeplHostFree) + util.OverwriteRequestHostHeader(headers, deeplHostFree) } else { return errors.New(`deepl model should be "Free" or "Pro"`) } return nil } + +func (d *deeplProvider) deeplTextGenRequest(request *chatCompletionRequest) *deeplRequest { + deeplRequest := &deeplRequest{ + Text: make([]string, 0), + TargetLang: d.config.targetLang, + } + for _, msg := range request.Messages { + if msg.Role == roleSystem { + deeplRequest.Context = msg.StringContent() + } else { + deeplRequest.Text = 
append(deeplRequest.Text, msg.StringContent()) + } + } + return deeplRequest +} + +func (d *deeplProvider) GetApiName(path string) ApiName { + if strings.Contains(path, deeplChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index ecca678670..c1eb57fe45 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -2,11 +2,11 @@ package provider import ( "errors" - "fmt" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) // deepseekProvider is the provider for deepseek Ai service. @@ -46,10 +46,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(deepseekChatCompletionPath) - _ = util.OverwriteRequestHost(deepseekDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -57,28 +54,19 @@ func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - 
if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.deepseek.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.deepseek.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, deepseekChatCompletionPath) + util.OverwriteRequestHostHeader(headers, deepseekDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *deepseekProvider) GetApiName(path string) ApiName { + if strings.Contains(path, deepseekChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 1358eebc7a..651b983206 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -2,11 +2,11 @@ package provider import ( "errors" - "fmt" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) const ( @@ -40,17 +40,10 @@ func (m *doubaoProvider) GetProviderType() string { } func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log 
wrapper.Log) (types.Action, error) { - _ = util.OverwriteRequestHost(doubaoDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - if m.config.protocol == protocolOriginal { - ctx.DontReadRequestBody() - return types.ActionContinue, nil - } if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(doubaoChatCompletionPath) + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -58,44 +51,19 @@ func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - if m.contextCache != nil { - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.doubao.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.doubao.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } else { - return 
types.ActionContinue, err - } - } else { - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.doubao.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - return types.ActionContinue, err - } - _ = proxywasm.ResumeHttpRequest() - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, doubaoChatCompletionPath) + util.OverwriteRequestHostHeader(headers, doubaoDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *doubaoProvider) GetApiName(path string) ApiName { + if strings.Contains(path, doubaoChatCompletionPath) { + return ApiNameChatCompletion } + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index c43f0fc622..32e92a4db4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -179,17 +179,22 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok { handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, originalHeaders, log) } - modifiedHeaders := util.HeaderToSlice(originalHeaders) var err error if handler, ok := activeProvider.(TransformRequestBodyHandler); ok { body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log) + } else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok { + headers := util.GetOriginalHttpHeaders() + body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, 
log) + util.ReplaceOriginalHttpHeaders(headers) } else { - body, err = c.defaultTransformRequestBody(ctx, body, log) + body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log) } if err != nil { return nil, nil, fmt.Errorf("failed to transform request body: %v", err) } + + modifiedHeaders := util.HeaderToSlice(originalHeaders) return modifiedHeaders, body, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index b9c157d864..a4c1ef2cd9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -17,8 +18,11 @@ import ( // geminiProvider is the provider for google gemini/gemini flash service. const ( - geminiApiKeyHeader = "x-goog-api-key" - geminiDomain = "generativelanguage.googleapis.com" + geminiApiKeyHeader = "x-goog-api-key" + geminiDomain = "generativelanguage.googleapis.com" + geminiChatCompletionPath = "generateContent" + geminiChatCompletionStreamPath = "streamGenerateContent?alt=sse" + geminiEmbeddingPath = "batchEmbedContents" ) type geminiProviderInitializer struct { @@ -51,157 +55,56 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } - - _ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) - _ = util.OverwriteRequestHost(geminiDomain) - - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - + g.config.handleRequestHeaders(g, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } -func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, 
body []byte, log wrapper.Log) (types.Action, error) { - if apiName == ApiNameChatCompletion { - return g.onChatCompletionRequestBody(ctx, body, log) - } else if apiName == ApiNameEmbeddings { - return g.onEmbeddingsRequestBody(ctx, body, log) - } - return types.ActionContinue, errUnsupportedApiName +func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, geminiDomain) + headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) // Set, not Add: replace any client-supplied key, as the previous ReplaceHttpRequestHeader call did. + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } -func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - // 使用gemini接口协议 - if g.config.protocol == protocolOriginal { - request := &geminiChatRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("request model is empty") - } - // 根据模型重写requestPath - path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) - _ = util.OverwriteRequestPath(path) - - // 移除多余的model和stream字段 - request = &geminiChatRequest{ - Contents: request.Contents, - SafetySettings: request.SafetySettings, - GenerationConfig: request.GenerationConfig, - Tools: request.Tools, - } - if g.config.context == nil { - return types.ActionContinue, replaceJsonRequestBody(request, log) - } - - err := g.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - g.setSystemContent(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = 
util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err +func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + return types.ActionContinue, errUnsupportedApiName } + return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) +} - // 映射模型重写requestPath - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, g.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") +func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + if apiName == ApiNameChatCompletion { + return g.onChatCompletionRequestBody(ctx, body, headers, log) + } else { + return g.onEmbeddingsRequestBody(ctx, body, headers, log) } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - path := g.getRequestPath(ApiNameChatCompletion, mappedModel, request.Stream) - _ = util.OverwriteRequestPath(path) +} - if g.config.context == nil { - geminiRequest := g.buildGeminiChatRequest(request) - return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log) +func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + 
request := &chatCompletionRequest{} + err := g.config.parseRequestAndMapModel(ctx, request, body, log) + if err != nil { + return nil, err } + path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) + util.OverwriteRequestPathHeader(headers, path) - err := g.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - geminiRequest := g.buildGeminiChatRequest(request) - if err := replaceJsonRequestBody(geminiRequest, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + geminiRequest := g.buildGeminiChatRequest(request) + return json.Marshal(geminiRequest) } -func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - // 使用gemini接口协议 - if g.config.protocol == protocolOriginal { - request := &geminiBatchEmbeddingRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("request model is empty") - } - // 根据模型重写requestPath - path := g.getRequestPath(ApiNameEmbeddings, request.Model, false) - _ = util.OverwriteRequestPath(path) - - // 移除多余的model字段 - request = &geminiBatchEmbeddingRequest{ - Requests: request.Requests, - } - return types.ActionContinue, replaceJsonRequestBody(request, log) - } +func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers 
http.Header, log wrapper.Log) ([]byte, error) { request := &embeddingsRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) + if err := g.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { + return nil, err } - - // 映射模型重写requestPath - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in embeddings request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, g.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - path := g.getRequestPath(ApiNameEmbeddings, mappedModel, false) - _ = util.OverwriteRequestPath(path) + path := g.getRequestPath(ApiNameEmbeddings, request.Model, false) + util.OverwriteRequestPathHeader(headers, path) geminiRequest := g.buildBatchEmbeddingRequest(request) - return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log) + return json.Marshal(geminiRequest) } func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -285,11 +188,11 @@ func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string { action := "" if apiName == ApiNameEmbeddings { - action = "batchEmbedContents" + action = geminiEmbeddingPath } else if stream { - action = "streamGenerateContent?alt=sse" + action = geminiChatCompletionStreamPath } else { - action = "generateContent" + action = geminiChatCompletionPath } return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action) } @@ -605,3 +508,13 @@ func (g *geminiProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, gemini func (g 
*geminiProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } + +func (g *geminiProvider) GetApiName(path string) ApiName { + if strings.Contains(path, geminiChatCompletionPath) || strings.Contains(path, geminiChatCompletionStreamPath) { + return ApiNameChatCompletion + } + if strings.Contains(path, geminiEmbeddingPath) { + return ApiNameEmbeddings + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index 7cf28f69cc..0a2b0c84de 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -1,13 +1,12 @@ package provider import ( - "encoding/json" "errors" - "fmt" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) // githubProvider is the provider for GitHub OpenAI service. 
@@ -47,16 +46,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(githubDomain) - if apiName == ApiNameChatCompletion { - _ = util.OverwriteRequestPath(githubCompletionPath) - } - if apiName == ApiNameEmbeddings { - _ = util.OverwriteRequestPath(githubEmbeddingPath) - } - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - _ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetApiTokenInUse(ctx)) + m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } @@ -65,47 +55,28 @@ func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } - if apiName == ApiNameChatCompletion { - return m.onChatCompletionRequestBody(ctx, body, log) - } - if apiName == ApiNameEmbeddings { - return m.onEmbeddingsRequestBody(ctx, body, log) - } - return types.ActionContinue, errUnsupportedApiName + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } -func (m *githubProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - if request.Model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") +func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, githubDomain) + if apiName == 
ApiNameChatCompletion { + util.OverwriteRequestPathHeader(headers, githubCompletionPath) } - // 映射模型 - mappedModel := getMappedModel(request.Model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + if apiName == ApiNameEmbeddings { + util.OverwriteRequestPathHeader(headers, githubEmbeddingPath) } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - request.Model = mappedModel - return types.ActionContinue, replaceJsonRequestBody(request, log) + util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } -func (m *githubProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - request := &embeddingsRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("missing model in embeddings request") +func (m *githubProvider) GetApiName(path string) ApiName { + if strings.Contains(path, githubCompletionPath) { + return ApiNameChatCompletion } - // 映射模型 - mappedModel := getMappedModel(request.Model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + if strings.Contains(path, githubEmbeddingPath) { + return ApiNameEmbeddings } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - request.Model = mappedModel - return types.ActionContinue, replaceJsonRequestBody(request, log) + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index c3eb74faf9..dfbd971261 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -3,6 +3,7 @@ package provider import ( "errors" "net/http" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -56,8 +57,15 @@ func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b } func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteHttpRequestPath(headers, groqChatCompletionPath) - util.OverwriteHttpRequestHost(headers, groqDomain) - util.OverwriteHttpRequestAuthorization(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) + util.OverwriteRequestPathHeader(headers, groqChatCompletionPath) + util.OverwriteRequestHostHeader(headers, groqDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } + +func (g *groqProvider) GetApiName(path string) ApiName { + if strings.Contains(path, groqChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index 7640a380b3..99cb135db6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -114,26 +115,27 @@ func (m *hunyuanProvider) GetProviderType() string { } func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - // log.Debugf("hunyuanProvider.OnRequestHeaders called! 
hunyunSecretKey/id is: %s/%s", m.config.hunyuanAuthKey, m.config.hunyuanAuthId) if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } + m.config.handleRequestHeaders(m, ctx, apiName, log) + // Delay the header processing to allow changing streaming mode in OnRequestBody + return types.HeaderStopIteration, nil +} - _ = util.OverwriteRequestHost(hunyuanDomain) - _ = util.OverwriteRequestPath(hunyuanRequestPath) - - // 添加hunyuan需要的自定义字段 - _ = proxywasm.ReplaceHttpRequestHeader(actionKey, hunyuanChatCompletionTCAction) - _ = proxywasm.ReplaceHttpRequestHeader(versionKey, versionValue) +func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, hunyuanDomain) + util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) - // 删除一些字段 - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + // 添加 hunyuan 需要的自定义字段 + headers.Add(actionKey, hunyuanChatCompletionTCAction) + headers.Add(versionKey, versionValue) - // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } +// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName @@ -142,7 +144,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName // 为header添加时间戳字段 (因为需要根据body进行签名时依赖时间戳,故于body处理部分创建时间戳) var timestamp int64 = time.Now().Unix() _ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp)) - // log.Debugf("#debug nash5# OnRequestBody set timestamp header: ", 
timestamp) // 使用混元本身接口的协议 if m.config.protocol == protocolOriginal { @@ -198,7 +199,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName if err := decodeChatCompletionRequest(body, request); err != nil { return types.ActionContinue, err } - // log.Debugf("#debug nash5# OnRequestBody call hunyuan api using openai's api!") model := request.Model if model == "" { @@ -235,18 +235,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName string(body), ) _ = util.OverwriteRequestAuthorization(authorizedValueNew) - // log.Debugf("#debug nash5# OnRequestBody done, body is: ", string(body)) - - // // 打印所有的headers - // headers, err2 := proxywasm.GetHttpRequestHeaders() - // if err2 != nil { - // log.Errorf("failed to get request headers: %v", err2) - // } else { - // // 迭代并打印所有请求头 - // for _, header := range headers { - // log.Infof("#debug nash5# inB Request header - %s: %s", header[0], header[1]) - // } - // } return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest, log) } @@ -277,6 +265,32 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName return types.ActionContinue, err } +// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用 +func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + request := &chatCompletionRequest{} + err := m.config.parseRequestAndMapModel(ctx, request, body, log) + if err != nil { + return nil, err + } + + hunyuanRequest := m.buildHunyuanTextGenerationRequest(request) + + var timestamp int64 = time.Now().Unix() + _ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp)) + // 根据确定好的payload进行签名: + body, _ = json.Marshal(hunyuanRequest) + authorizedValueNew := GetTC3Authorizationcode( + m.config.hunyuanAuthId, + m.config.hunyuanAuthKey, + timestamp, + hunyuanDomain, + hunyuanChatCompletionTCAction, + 
string(body), + ) + util.OverwriteRequestAuthorizationHeader(headers, authorizedValueNew) + return json.Marshal(hunyuanRequest) +} + func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { _ = proxywasm.RemoveHttpResponseHeader("Content-Length") return types.ActionContinue, nil @@ -561,3 +575,7 @@ func GetTC3Authorizationcode(secretId string, secretKey string, timestamp int64, // fmt.Println(curl) return authorization } + +func (m *hunyuanProvider) GetApiName(path string) ApiName { + return ApiNameChatCompletion +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 62173deeb3..00aa0f7254 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -78,14 +79,17 @@ func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(minimaxDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - + m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } +func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, minimaxDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body 
[]byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName @@ -107,51 +111,16 @@ func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName return m.handleRequestBodyByChatCompletionPro(body, log) } else { // 使用ChatCompletion v2接口 - return m.handleRequestBodyByChatCompletionV2(body, log) + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } } +func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + return m.handleRequestBodyByChatCompletionV2(body, headers, log) +} + // handleRequestBodyByChatCompletionPro 使用ChatCompletion Pro接口处理请求体 func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) { - // 使用minimax接口协议 - if m.config.protocol == protocolOriginal { - request := &minimaxChatCompletionV2Request{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("request model is empty") - } - // 根据模型重写requestPath - if m.config.minimaxGroupId == "" { - return types.ActionContinue, errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when use %s model ", request.Model)) - } - _ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId)) - - if m.config.context == nil { - return types.ActionContinue, nil - } - - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - 
m.setBotSettings(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err - } - request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { return types.ActionContinue, err @@ -174,6 +143,9 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log log.Errorf("failed to load context file: %v", err) _ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) } + // 由于 minimaxChatCompletionV2(格式和 OpenAI 一致)和 minimaxChatCompletionPro(格式和 OpenAI 不一致)中 insertHttpContextMessage 的逻辑不同,无法做到同一个 provider 统一 + // 因此对于 minimaxChatCompletionPro 需要手动处理 context 消息 + // minimaxChatCompletionV2 交给默认的 defaultInsertHttpContextMessage 方法插入 context 消息 minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content) if err := replaceJsonRequestBody(minimaxRequest, log); err != nil { _ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err)) @@ -186,37 +158,17 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log } // handleRequestBodyByChatCompletionV2 使用ChatCompletion v2接口处理请求体 -func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, log wrapper.Log) (types.Action, error) { +func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + return nil, err } // 映射模型重写requestPath request.Model = getMappedModel(request.Model, 
m.config.modelMapping, log) - _ = util.OverwriteRequestPath(minimaxChatCompletionV2Path) - - if m.contextCache == nil { - return types.ActionContinue, replaceJsonRequestBody(request, log) - } + util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path) - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + return body, nil } func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -474,3 +426,10 @@ func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Re func (m *minimaxProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } + +func (m *minimaxProvider) GetApiName(path string) ApiName { + if strings.Contains(path, minimaxChatCompletionV2Path) || strings.Contains(path, minimaxChatCompletionProPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index d44d3ce55a..23b870cff0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -2,15 +2,16 @@ package provider import ( "errors" - "fmt" 
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) const ( - mistralDomain = "api.mistral.ai" + mistralDomain = "api.mistral.ai" + mistralChatCompletionPath = "/v1/chat/completions" ) type mistralProviderInitializer struct{} @@ -42,9 +43,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(mistralDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -52,28 +51,18 @@ func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.mistral.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.mistral.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return 
m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, mistralDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *mistralProvider) GetApiName(path string) ApiName { + if strings.Contains(path, mistralChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index e76bcf529d..de40471c92 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -58,33 +59,29 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(moonshotChatCompletionPath) - _ = util.OverwriteRequestHost(moonshotDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } +func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, moonshotChatCompletionPath) + util.OverwriteRequestHostHeader(headers, moonshotDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + 
headers.Del("Content-Length") +} + +// moonshot 有自己获取 context 的配置(moonshotFileId),因此无法复用 handleRequestBody 方法 +// moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法 func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { + if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { return types.ActionContinue, err } - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - if m.config.moonshotFileId == "" && m.contextCache == nil { return types.ActionContinue, replaceJsonRequestBody(request, log) } @@ -154,3 +151,10 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba return errors.New("unsupported method: " + method) } } + +func (m *moonshotProvider) GetApiName(path string) ApiName { + if strings.Contains(path, moonshotChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 310dbeb3a1..3f1303d750 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -5,8 +5,9 @@ import ( "fmt" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" 
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) // ollamaProvider is the provider for Ollama service. @@ -52,10 +53,7 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(ollamaChatCompletionPath) - _ = util.OverwriteRequestHost(m.serviceDomain) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -63,51 +61,18 @@ func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} - if m.config.modelMapping == nil && m.contextCache == nil { - return types.ActionContinue, nil - } - - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel +func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, ollamaChatCompletionPath) + util.OverwriteRequestHostHeader(headers, m.serviceDomain) + headers.Del("Content-Length") +} - if m.contextCache != nil { - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", 
err) - _ = util.SendResponse(500, "ai-proxy.ollama.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.ollama.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } else { - return types.ActionContinue, err - } - } else { - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.ollama.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - return types.ActionContinue, err - } - _ = proxywasm.ResumeHttpRequest() - return types.ActionPause, nil +func (m *ollamaProvider) GetApiName(path string) ApiName { + if strings.Contains(path, ollamaChatCompletionPath) { + return ApiNameChatCompletion } + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 42d991737f..ab92191dfe 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -1,12 +1,13 @@ package provider import ( + "encoding/json" "fmt" + "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -57,27 +58,31 @@ func (m *openaiProvider) GetProviderType() string { } func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + m.config.handleRequestHeaders(m, ctx, apiName, log) + return types.ActionContinue, nil +} + +func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, 
apiName ApiName, headers http.Header, log wrapper.Log) { if m.customPath == "" { switch apiName { case ApiNameChatCompletion: - _ = util.OverwriteRequestPath(defaultOpenaiChatCompletionPath) + util.OverwriteRequestPathHeader(headers, defaultOpenaiChatCompletionPath) case ApiNameEmbeddings: ctx.DontReadRequestBody() - _ = util.OverwriteRequestPath(defaultOpenaiEmbeddingsPath) + util.OverwriteRequestPathHeader(headers, defaultOpenaiEmbeddingsPath) } } else { - _ = util.OverwriteRequestPath(m.customPath) + util.OverwriteRequestPathHeader(headers, m.customPath) } if m.customDomain == "" { - _ = util.OverwriteRequestHost(defaultOpenaiDomain) + util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain) } else { - _ = util.OverwriteRequestHost(m.customDomain) + util.OverwriteRequestHostHeader(headers, m.customDomain) } if len(m.config.apiTokens) > 0 { - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) } - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - return types.ActionContinue, nil + headers.Del("Content-Length") } func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { @@ -85,9 +90,13 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, // We don't need to process the request body for other APIs. 
return types.ActionContinue, nil } + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + return nil, err } if m.config.responseJsonSchema != nil { log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema) @@ -101,27 +110,9 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, request.StreamOptions.IncludeUsage = true } } - if m.contextCache == nil { - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - return types.ActionContinue, nil - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.openai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.openai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + return json.Marshal(request) +} + +func (m *openaiProvider) GetApiName(path string) ApiName { + return GetOpenAiApiName(path) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index a5a75e8027..e1080cf7ce 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -3,6 +3,7 @@ package provider import ( "encoding/json" "errors" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "math/rand" "net/http" "strings" @@ -108,6 +109,7 @@ var ( type Provider interface { GetProviderType() string + GetApiName(path string) ApiName } type RequestHeadersHandler interface { @@ -126,6 +128,12 @@ type TransformRequestBodyHandler interface { TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) } +// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body. +// Some providers (e.g. baidu, gemini) transform request headers (e.g., path) based on the request body (e.g., model). +type TransformRequestBodyHeadersHandler interface { + TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) +} + type ResponseHeadersHandler interface { OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) } @@ -386,6 +394,12 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques if err := decodeChatCompletionRequest(body, req); err != nil { return err } + + streaming := req.Stream + if streaming { + _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") + } + return c.setRequestModel(ctx, req, log) case *embeddingsRequest: if err := decodeEmbeddingsRequest(body, req); err != nil { @@ -477,8 +491,12 @@ func (c *ProviderConfig) handleRequestBody( var err error if handler, ok := provider.(TransformRequestBodyHandler); ok { body, err = handler.TransformRequestBody(ctx, apiName, body, log) + } else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok { + headers := util.GetOriginalHttpHeaders() + body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, 
headers, log) + util.ReplaceOriginalHttpHeaders(headers) } else { - body, err = c.defaultTransformRequestBody(ctx, body, log) + body, err = c.defaultTransformRequestBody(ctx, apiName, body, log) } if err != nil { @@ -507,12 +525,25 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt } } -func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { - request := &chatCompletionRequest{} - err := c.parseRequestAndMapModel(ctx, request, body, log) - if err != nil { +func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + var request interface{} + if apiName == ApiNameChatCompletion { + request = &chatCompletionRequest{} + } else { + request = &embeddingsRequest{} + } + if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil { return nil, err } - return json.Marshal(request) } + +func GetOpenAiApiName(path string) ApiName { + if strings.HasSuffix(path, "/v1/chat/completions") { + return ApiNameChatCompletion + } + if strings.HasSuffix(path, "/v1/embeddings") { + return ApiNameEmbeddings + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 2709e89029..771feeb51e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -59,30 +59,29 @@ func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provide } type qwenProvider struct { - config ProviderConfig - + config ProviderConfig contextCache *contextCache } func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteHttpRequestHost(headers, qwenDomain) - util.OverwriteHttpRequestAuthorization(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + 
util.OverwriteRequestHostHeader(headers, qwenDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) if m.config.qwenEnableCompatible { - util.OverwriteHttpRequestPath(headers, qwenCompatiblePath) + util.OverwriteRequestPathHeader(headers, qwenCompatiblePath) } else if apiName == ApiNameChatCompletion { - util.OverwriteHttpRequestPath(headers, qwenChatCompletionPath) + util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath) } else if apiName == ApiNameEmbeddings { - util.OverwriteHttpRequestPath(headers, qwenTextEmbeddingPath) + util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath) } headers.Del("Accept-Encoding") headers.Del("Content-Length") } -func (m *qwenProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { +func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { if apiName == ApiNameChatCompletion { - return m.onChatCompletionRequestBody(ctx, body, log) + return m.onChatCompletionRequestBody(ctx, body, headers, log) } else { return m.onEmbeddingsRequestBody(ctx, body, log) } @@ -145,7 +144,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } -func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { +func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} err := m.config.parseRequestAndMapModel(ctx, request, body, log) if err != nil { @@ -154,7 +153,7 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body // Use the qwen multimodal model generation API if strings.HasPrefix(request.Model, 
qwenVlModelPrefixName) { - _ = util.OverwriteRequestPath(qwenMultimodalGenerationPath) + util.OverwriteRequestPathHeader(headers, qwenMultimodalGenerationPath) } streaming := request.Stream @@ -171,8 +170,7 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { request := &embeddingsRequest{} - err := m.config.parseRequestAndMapModel(ctx, request, body, log) - if err != nil { + if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { return nil, err } @@ -755,3 +753,16 @@ func chatMessage2QwenMessage(chatMessage chatMessage) qwenMessage { } } } + +func (m *qwenProvider) GetApiName(path string) ApiName { + switch { + case strings.Contains(path, qwenChatCompletionPath), + strings.Contains(path, qwenMultimodalGenerationPath), + strings.Contains(path, qwenCompatiblePath): + return ApiNameChatCompletion + case strings.Contains(path, qwenTextEmbeddingPath): + return ApiNameEmbeddings + default: + return "" + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 67c1e001bd..e39bdaded9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -2,8 +2,8 @@ package provider import ( "encoding/json" - "errors" "fmt" + "net/http" "strings" "time" @@ -71,11 +71,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(sparkHost) - _ = util.OverwriteRequestPath(sparkChatCompletionPath) - _ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + 
p.config.handleRequestHeaders(p, ctx, apiName, log) return types.ActionContinue, nil } @@ -83,36 +79,7 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - // 使用Spark协议 - if p.config.protocol == protocolOriginal { - request := &sparkRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("request model is empty") - } - // 目前星火在模型名称错误时,也会调用generalv3,这里还是按照输入的模型名称设置响应里的模型名称 - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - return types.ActionContinue, replaceJsonRequestBody(request, log) - } else { - // 使用openai协议 - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - if request.Model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - // 映射模型 - mappedModel := getMappedModel(request.Model, p.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - request.Model = mappedModel - return types.ActionContinue, replaceJsonRequestBody(request, log) - } + return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log) } func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -205,3 +172,18 @@ func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, resp func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } + +func (p *sparkProvider) 
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath) + util.OverwriteRequestHostHeader(headers, sparkHost) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx)) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") +} + +func (p *sparkProvider) GetApiName(path string) ApiName { + if strings.Contains(path, sparkChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index 2bc80f6f81..f96e59e65b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -2,10 +2,11 @@ package provider import ( "errors" - "fmt" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -44,10 +45,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(stepfunChatCompletionPath) - _ = util.OverwriteRequestHost(stepfunDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -55,28 +53,19 @@ func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := 
&chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.stepfun.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.stepfun.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, stepfunChatCompletionPath) + util.OverwriteRequestHostHeader(headers, stepfunDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *stepfunProvider) GetApiName(path string) ApiName { + if strings.Contains(path, stepfunChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 511d251553..ef1141304e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -2,10 +2,11 @@ package provider import ( "errors" - "fmt" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - 
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -44,10 +45,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(yiChatCompletionPath) - _ = util.OverwriteRequestHost(yiDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -55,28 +53,19 @@ func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, bod if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.yi.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.yi.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, yiChatCompletionPath) + 
util.OverwriteRequestHostHeader(headers, yiDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *yiProvider) GetApiName(path string) ApiName { + if strings.Contains(path, yiChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 98c824e4f3..40fbe4ef88 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -2,10 +2,11 @@ package provider import ( "errors" - "fmt" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -43,10 +44,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(zhipuAiChatCompletionPath) - _ = util.OverwriteRequestHost(zhipuAiDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx)) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -54,28 +52,19 @@ func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - 
defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.zhihupai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.zhihupai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, zhipuAiChatCompletionPath) + util.OverwriteRequestHostHeader(headers, zhipuAiDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *zhipuAiProvider) GetApiName(path string) ApiName { + if strings.Contains(path, zhipuAiChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index edaf60506b..f0d4c0ce7c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -25,13 +25,6 @@ func CreateHeaders(kvs ...string) [][2]string { return headers } -func OverwriteRequestHost(host string) error { - if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil { - _ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-HOST", originHost) - } - return proxywasm.ReplaceHttpRequestHeader(":authority", host) -} - func OverwriteRequestPath(path string) error { if originPath, err := 
proxywasm.GetHttpRequestHeader(":path"); err == nil { _ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-PATH", originPath) @@ -48,21 +41,21 @@ func OverwriteRequestAuthorization(credential string) error { return proxywasm.ReplaceHttpRequestHeader("Authorization", credential) } -func OverwriteHttpRequestHost(headers http.Header, host string) { +func OverwriteRequestHostHeader(headers http.Header, host string) { if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil { headers.Set("X-ENVOY-ORIGINAL-HOST", originHost) } headers.Set(":authority", host) } -func OverwriteHttpRequestPath(headers http.Header, path string) { +func OverwriteRequestPathHeader(headers http.Header, path string) { if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil { headers.Set("X-ENVOY-ORIGINAL-PATH", originPath) } headers.Set(":path", path) } -func OverwriteHttpRequestAuthorization(headers http.Header, credential string) { +func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) { if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" { if originAuth := headers.Get("Authorization"); originAuth != "" { headers.Set("X-HI-ORIGINAL-AUTH", originAuth) From 51c766fa7ced7dd211457e391f572a50a1686fff Mon Sep 17 00:00:00 2001 From: Se7en Date: Thu, 14 Nov 2024 11:18:30 +0800 Subject: [PATCH 30/31] fix --- plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go | 5 +++-- plugins/wasm-go/extensions/ai-proxy/provider/provider.go | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index 3ed477602c..e016dc4553 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -2,11 +2,12 @@ package provider import ( "errors" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" 
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" ) // baichuanProvider is the provider for baichuan Ai service. diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index e1080cf7ce..41160464ac 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -3,13 +3,13 @@ package provider import ( "encoding/json" "errors" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "math/rand" "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" ) From 0b3422a1622867f9ca74bd6d1820959978b83bf3 Mon Sep 17 00:00:00 2001 From: Se7en Date: Thu, 14 Nov 2024 23:09:50 +0800 Subject: [PATCH 31/31] make GetApiName optional --- plugins/wasm-go/extensions/ai-proxy/main.go | 19 ++++++++++++++----- .../extensions/ai-proxy/provider/ai360.go | 11 +---------- .../extensions/ai-proxy/provider/azure.go | 12 ------------ .../extensions/ai-proxy/provider/baichuan.go | 8 -------- .../ai-proxy/provider/cloudflare.go | 12 ++---------- .../extensions/ai-proxy/provider/deepseek.go | 8 -------- .../extensions/ai-proxy/provider/mistral.go | 11 +---------- .../extensions/ai-proxy/provider/moonshot.go | 11 +---------- .../extensions/ai-proxy/provider/ollama.go | 8 -------- .../extensions/ai-proxy/provider/openai.go | 4 ---- .../extensions/ai-proxy/provider/provider.go | 13 +++---------- .../extensions/ai-proxy/provider/spark.go | 7 ------- .../extensions/ai-proxy/provider/stepfun.go | 11 +---------- .../extensions/ai-proxy/provider/yi.go | 8 -------- 14 files changed, 23 insertions(+), 120 deletions(-) diff --git 
a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 3a29575c21..6c83dd3901 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -80,13 +80,12 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf rawPath := ctx.Path() path, _ := url.Parse(rawPath) - - var apiName provider.ApiName + apiName := getOpenAiApiName(path.Path) providerConfig := pluginConfig.GetProviderConfig() if providerConfig.IsOriginal() { - apiName = activeProvider.GetApiName(path.Path) - } else { - apiName = provider.GetOpenAiApiName(path.Path) + if handler, ok := activeProvider.(provider.ApiNameHandler); ok { + apiName = handler.GetApiName(path.Path) + } } if apiName == "" { @@ -263,3 +262,13 @@ func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) { (*ctx).BufferResponseBody() } } + +func getOpenAiApiName(path string) provider.ApiName { + if strings.HasSuffix(path, "/v1/chat/completions") { + return provider.ApiNameChatCompletion + } + if strings.HasSuffix(path, "/v1/embeddings") { + return provider.ApiNameEmbeddings + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 439fd8ff10..6f42d570d0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -3,7 +3,6 @@ package provider import ( "errors" "net/http" - "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -12,8 +11,7 @@ import ( // ai360Provider is the provider for 360 OpenAI service. 
const ( - ai360Domain = "api.360.cn" - ai360ChatCompletionPath = "/v1/chat/completions" + ai360Domain = "api.360.cn" ) type ai360ProviderInitializer struct { @@ -64,10 +62,3 @@ func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName headers.Del("Accept-Encoding") headers.Del("Content-Length") } - -func (m *ai360Provider) GetApiName(path string) ApiName { - if strings.Contains(path, ai360ChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 9919aeb073..1a79908d4e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -5,17 +5,12 @@ import ( "fmt" "net/http" "net/url" - "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) -const ( - azureChatCompletionPath = "/chat/completions" -) - // azureProvider is the provider for Azure OpenAI service. 
type azureProviderInitializer struct { } @@ -79,10 +74,3 @@ func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } - -func (m *azureProvider) GetApiName(path string) ApiName { - if strings.Contains(path, azureChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index e016dc4553..b43ba8ee26 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -3,7 +3,6 @@ package provider import ( "errors" "net/http" - "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -64,10 +63,3 @@ func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } - -func (m *baichuanProvider) GetApiName(path string) ApiName { - if strings.Contains(path, baichuanChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index a4c02381fa..2f6108b0df 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -13,8 +13,7 @@ import ( const ( cloudflareDomain = "api.cloudflare.com" // https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility/ - cloudflareChatCompletionPath = "/v1/chat/completions" - cloudflareChatCompletionFullPath = "/client/v4/accounts/{account_id}/ai/v1/chat/completions" + cloudflareChatCompletionPath = 
"/client/v4/accounts/{account_id}/ai/v1/chat/completions" ) type cloudflareProviderInitializer struct { @@ -59,16 +58,9 @@ func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiN } func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionFullPath, "{account_id}", c.config.cloudflareAccountId, 1)) + util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1)) util.OverwriteRequestHostHeader(headers, cloudflareDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx)) headers.Del("Accept-Encoding") headers.Del("Content-Length") } - -func (c *cloudflareProvider) GetApiName(path string) ApiName { - if strings.Contains(path, cloudflareChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index c1eb57fe45..9cad3928f5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -6,7 +6,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "net/http" - "strings" ) // deepseekProvider is the provider for deepseek Ai service. 
@@ -63,10 +62,3 @@ func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } - -func (m *deepseekProvider) GetApiName(path string) ApiName { - if strings.Contains(path, deepseekChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 23b870cff0..3e5323a60c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -6,12 +6,10 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "net/http" - "strings" ) const ( - mistralDomain = "api.mistral.ai" - mistralChatCompletionPath = "/v1/chat/completions" + mistralDomain = "api.mistral.ai" ) type mistralProviderInitializer struct{} @@ -59,10 +57,3 @@ func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } - -func (m *mistralProvider) GetApiName(path string) ApiName { - if strings.Contains(path, mistralChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index de40471c92..cb914d8c85 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -3,14 +3,12 @@ package provider import ( "errors" "fmt" - "net/http" - "strings" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" 
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" + "net/http" ) // moonshotProvider is the provider for Moonshot AI service. @@ -151,10 +149,3 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba return errors.New("unsupported method: " + method) } } - -func (m *moonshotProvider) GetApiName(path string) ApiName { - if strings.Contains(path, moonshotChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 3f1303d750..5339083819 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -7,7 +7,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "net/http" - "strings" ) // ollamaProvider is the provider for Ollama service. @@ -69,10 +68,3 @@ func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam util.OverwriteRequestHostHeader(headers, m.serviceDomain) headers.Del("Content-Length") } - -func (m *ollamaProvider) GetApiName(path string) ApiName { - if strings.Contains(path, ollamaChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index ab92191dfe..60c835cd49 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -112,7 +112,3 @@ func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A } return json.Marshal(request) } - -func (m *openaiProvider) GetApiName(path string) ApiName { - return GetOpenAiApiName(path) -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go 
index 41160464ac..1620ff10a8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -109,6 +109,9 @@ var ( type Provider interface { GetProviderType() string +} + +type ApiNameHandler interface { GetApiName(path string) ApiName } @@ -537,13 +540,3 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap } return json.Marshal(request) } - -func GetOpenAiApiName(path string) ApiName { - if strings.HasSuffix(path, "/v1/chat/completions") { - return ApiNameChatCompletion - } - if strings.HasSuffix(path, "/v1/embeddings") { - return ApiNameEmbeddings - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index e39bdaded9..c2e013643c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -180,10 +180,3 @@ func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName headers.Del("Accept-Encoding") headers.Del("Content-Length") } - -func (p *sparkProvider) GetApiName(path string) ApiName { - if strings.Contains(path, sparkChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index f96e59e65b..1ee01abe62 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -2,12 +2,10 @@ package provider import ( "errors" - "net/http" - "strings" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" ) const ( @@ -62,10 +60,3 @@ func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa 
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } - -func (m *stepfunProvider) GetApiName(path string) ApiName { - if strings.Contains(path, stepfunChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index ef1141304e..7cb05a9388 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -3,7 +3,6 @@ package provider import ( "errors" "net/http" - "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -62,10 +61,3 @@ func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName Ap util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } - -func (m *yiProvider) GetApiName(path string) ApiName { - if strings.Contains(path, yiChatCompletionPath) { - return ApiNameChatCompletion - } - return "" -}