diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md
index 2a8a838b1d..9d7dbc796b 100644
--- a/plugins/wasm-go/extensions/ai-proxy/README.md
+++ b/plugins/wasm-go/extensions/ai-proxy/README.md
@@ -31,15 +31,16 @@ description: AI 代理插件配置参考
`provider`的配置字段说明如下:
-| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
-| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `type` | string | 必填 | - | AI 服务提供商名称 |
-| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
-| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
-| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
-| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
-| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
-| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
+| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
+|------------------| --------------- | -------- | ------ |-----------------------------------------------------------------------------------------------------------------------------------------------------------|
+| `type` | string | 必填 | - | AI 服务提供商名称 |
+| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
+| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
+| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
+| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
+| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
+| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
+| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
`context`的配置字段说明如下:
@@ -75,6 +76,16 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字
如果启用了raw模式,custom-setting会直接用输入的`name`和`value`去更改请求中的json内容,而不对参数名称做任何限制和修改。
对于大多数协议,custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议,ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。
+`failover` 的配置字段说明如下:
+
+| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
+|------------------|--------|------|-------|-----------------------------|
+| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
+| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
+| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
+| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
+| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
+| healthCheckModel | string | 必填 | | 健康检测使用的模型 |
### 提供商特有配置
diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go
index b545271a70..48f08dd9e4 100644
--- a/plugins/wasm-go/extensions/ai-proxy/config/config.go
+++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go
@@ -1,9 +1,9 @@
package config
import (
- "github.com/tidwall/gjson"
-
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
+ "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+ "github.com/tidwall/gjson"
)
// @Name ai-proxy
@@ -75,13 +75,17 @@ func (c *PluginConfig) Validate() error {
return nil
}
-func (c *PluginConfig) Complete() error {
+func (c *PluginConfig) Complete(log wrapper.Log) error {
if c.activeProviderConfig == nil {
c.activeProvider = nil
return nil
}
var err error
c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig)
+
+ providerConfig := c.GetProviderConfig()
+ err = providerConfig.SetApiTokensFailover(log, c.activeProvider)
+
return err
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go
index 7b19d03fc2..7527fb8e4b 100644
--- a/plugins/wasm-go/extensions/ai-proxy/main.go
+++ b/plugins/wasm-go/extensions/ai-proxy/main.go
@@ -44,9 +44,10 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
if err := pluginConfig.Validate(); err != nil {
return err
}
- if err := pluginConfig.Complete(); err != nil {
+ if err := pluginConfig.Complete(log); err != nil {
return err
}
+
return nil
}
@@ -59,9 +60,10 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
if err := pluginConfig.Validate(); err != nil {
return err
}
- if err := pluginConfig.Complete(); err != nil {
+ if err := pluginConfig.Complete(log); err != nil {
return err
}
+
return nil
}
@@ -80,7 +82,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
- if apiName == "" && !providerConfig.IsOriginal() {
+ if providerConfig.IsOriginal() {
+ if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
+ apiName = handler.GetApiName(path.Path)
+ }
+ }
+
+ if apiName == "" {
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
log.Debugf("[onHttpRequestHeader] no send response")
@@ -89,8 +97,11 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
ctx.SetContext(ctxKeyApiName, apiName)
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
- // Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
+ // Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
+ // Set the apiToken for the current request.
+ providerConfig.SetApiTokenInUse(ctx, log)
+
hasRequestBody := wrapper.HasRequestBody()
action, err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
@@ -102,6 +113,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
return action
}
+
_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
return types.ActionContinue
}
@@ -156,15 +168,24 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType())
+ providerConfig := pluginConfig.GetProviderConfig()
+ apiTokenInUse := providerConfig.GetApiTokenInUse(ctx)
+
status, err := proxywasm.GetHttpResponseHeader(":status")
if err != nil || status != "200" {
if err != nil {
log.Errorf("unable to load :status header from response: %v", err)
}
ctx.DontReadResponseBody()
+ providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)
+
return types.ActionContinue
}
+ // Reset ctxApiTokenRequestFailureCount if the request is successful,
+ // the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
+ providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
+
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
@@ -233,16 +254,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
return types.ActionContinue
}
-func getOpenAiApiName(path string) provider.ApiName {
- if strings.HasSuffix(path, "/v1/chat/completions") {
- return provider.ApiNameChatCompletion
- }
- if strings.HasSuffix(path, "/v1/embeddings") {
- return provider.ApiNameEmbeddings
- }
- return ""
-}
-
func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
@@ -252,3 +263,13 @@ func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
(*ctx).BufferResponseBody()
}
}
+
+func getOpenAiApiName(path string) provider.ApiName {
+ if strings.HasSuffix(path, "/v1/chat/completions") {
+ return provider.ApiNameChatCompletion
+ }
+ if strings.HasSuffix(path, "/v1/embeddings") {
+ return provider.ApiNameEmbeddings
+ }
+ return ""
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
index 00443fcf5e..6f42d570d0 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
@@ -1,14 +1,12 @@
package provider
import (
- "encoding/json"
"errors"
- "fmt"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+ "net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+ "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// ai360Provider is the provider for 360 OpenAI service.
@@ -46,10 +44,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestHost(ai360Domain)
- _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
- _ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
@@ -58,47 +53,12 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
- if apiName == ApiNameChatCompletion {
- return m.onChatCompletionRequestBody(ctx, body, log)
- }
- if apiName == ApiNameEmbeddings {
- return m.onEmbeddingsRequestBody(ctx, body, log)
- }
- return types.ActionContinue, errUnsupportedApiName
-}
-
-func (m *ai360Provider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- if request.Model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
- }
- // 映射模型
- mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
- request.Model = mappedModel
- return types.ActionContinue, replaceJsonRequestBody(request, log)
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
-func (m *ai360Provider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
- request := &embeddingsRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
- if request.Model == "" {
- return types.ActionContinue, errors.New("missing model in embeddings request")
- }
- // 映射模型
- mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
- request.Model = mappedModel
- return types.ActionContinue, replaceJsonRequestBody(request, log)
+func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestHostHeader(headers, ai360Domain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Accept-Encoding")
+ headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go
index 2dcba2f8ff..1a79908d4e 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go
@@ -3,16 +3,15 @@ package provider
import (
"errors"
"fmt"
+ "net/http"
"net/url"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// azureProvider is the provider for Azure OpenAI service.
-
type azureProviderInitializer struct {
}
@@ -55,47 +54,23 @@ func (m *azureProvider) GetProviderType() string {
}
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
- _ = util.OverwriteRequestPath(m.serviceUrl.RequestURI())
- _ = util.OverwriteRequestHost(m.serviceUrl.Host)
- _ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.apiTokens[0])
- if apiName == ApiNameChatCompletion {
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
- } else {
- ctx.DontReadRequestBody()
+ if apiName != ApiNameChatCompletion {
+ return types.ActionContinue, errUnsupportedApiName
}
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
- // We don't need to process the request body for other APIs.
- return types.ActionContinue, nil
- }
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
+ return types.ActionContinue, errUnsupportedApiName
}
- if m.contextCache == nil {
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- return types.ActionContinue, nil
- }
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.azure.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.azure.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
+
+func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
+ util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
+ util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
index c16a8e4395..b43ba8ee26 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
@@ -2,11 +2,10 @@ package provider
import (
"errors"
- "fmt"
+ "net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -47,10 +46,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(baichuanChatCompletionPath)
- _ = util.OverwriteRequestHost(baichuanDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -58,28 +54,12 @@ func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- if m.contextCache == nil {
- return types.ActionContinue, nil
- }
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.baichuan.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.baichuan.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
+
+func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, baichuanDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
index fc779d5306..42a1bc723d 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "net/http"
"strings"
"time"
@@ -16,7 +17,8 @@ import (
// baiduProvider is the provider for baidu ernie bot service.
const (
- baiduDomain = "aip.baidubce.com"
+ baiduDomain = "aip.baidubce.com"
+ baiduChatCompletionPath = "/chat"
)
var baiduModelToPathSuffixMap = map[string]string{
@@ -60,98 +62,35 @@ func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestHost(baiduDomain)
-
- _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
-
+ b.config.handleRequestHeaders(b, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
+func (b *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestHostHeader(headers, baiduDomain)
+ headers.Del("Accept-Encoding")
+ headers.Del("Content-Length")
+}
+
func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- // 使用文心一言接口协议
- if b.config.protocol == protocolOriginal {
- request := &baiduTextGenRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
- if request.Model == "" {
- return types.ActionContinue, errors.New("request model is empty")
- }
- // 根据模型重写requestPath
- path := b.getRequestPath(request.Model)
- _ = util.OverwriteRequestPath(path)
-
- if b.config.context == nil {
- return types.ActionContinue, nil
- }
+ return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body, log)
+}
- err := b.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
-
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- b.setSystemContent(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
- }
+func (b *baiduProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
+ err := b.config.parseRequestAndMapModel(ctx, request, body, log)
+ if err != nil {
+ return nil, err
}
+ path := b.getRequestPath(ctx, request.Model)
+ util.OverwriteRequestPathHeader(headers, path)
- // 映射模型重写requestPath
- model := request.Model
- if model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
- }
- ctx.SetContext(ctxKeyOriginalRequestModel, model)
- mappedModel := getMappedModel(model, b.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- request.Model = mappedModel
- ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
- path := b.getRequestPath(mappedModel)
- _ = util.OverwriteRequestPath(path)
-
- if b.config.context == nil {
- baiduRequest := b.baiduTextGenRequest(request)
- return types.ActionContinue, replaceJsonRequestBody(baiduRequest, log)
- }
-
- err := b.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- baiduRequest := b.baiduTextGenRequest(request)
- if err := replaceJsonRequestBody(baiduRequest, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ baiduRequest := b.baiduTextGenRequest(request)
+ return json.Marshal(baiduRequest)
}
func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -226,13 +165,13 @@ type baiduTextGenRequest struct {
UserId string `json:"user_id,omitempty"`
}
-func (b *baiduProvider) getRequestPath(baiduModel string) string {
+func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix, ok := baiduModelToPathSuffixMap[baiduModel]
if !ok {
suffix = baiduModel
}
- return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken())
+ return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetApiTokenInUse(ctx))
}
func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) {
@@ -339,3 +278,10 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp
func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
+
+func (b *baiduProvider) GetApiName(path string) ApiName {
+ if strings.Contains(path, baiduChatCompletionPath) {
+ return ApiNameChatCompletion
+ }
+ return ""
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go
index 7bbbc93d79..8b98d62d64 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "net/http"
"strings"
"time"
@@ -105,102 +106,39 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
+ c.config.handleRequestHeaders(c, ctx, apiName, log)
+ return types.ActionContinue, nil
+}
+
+func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, claudeDomain)
- _ = util.OverwriteRequestPath(claudeChatCompletionPath)
- _ = util.OverwriteRequestHost(claudeDomain)
- _ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetRandomToken())
+ headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx))
if c.config.claudeVersion == "" {
c.config.claudeVersion = defaultVersion
}
- _ = proxywasm.AddHttpRequestHeader("anthropic-version", c.config.claudeVersion)
- _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
- return types.ActionContinue, nil
+ headers.Add("anthropic-version", c.config.claudeVersion)
+ headers.Del("Accept-Encoding")
+ headers.Del("Content-Length")
}
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
+ return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
+}
- // use original protocol
- if c.config.protocol == protocolOriginal {
- if c.config.context == nil {
- return types.ActionContinue, nil
- }
-
- request := &claudeTextGenRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
-
- err := c.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
-
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
- }
-
- // use openai protocol
+func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
-
- model := request.Model
- if model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
+ if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
+ return nil, err
}
- ctx.SetContext(ctxKeyOriginalRequestModel, model)
- mappedModel := getMappedModel(model, c.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- request.Model = mappedModel
- ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
-
- streaming := request.Stream
- if streaming {
- _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
- }
-
- if c.config.context == nil {
- claudeRequest := c.buildClaudeTextGenRequest(request)
- return types.ActionContinue, replaceJsonRequestBody(claudeRequest, log)
- }
-
- err := c.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- claudeRequest := c.buildClaudeTextGenRequest(request)
- if err := replaceJsonRequestBody(claudeRequest, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ claudeRequest := c.buildClaudeTextGenRequest(request)
+ return json.Marshal(claudeRequest)
}
func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -369,3 +307,25 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG
func (c *claudeProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
+
+func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
+ request := &claudeTextGenRequest{}
+ if err := json.Unmarshal(body, request); err != nil {
+ return nil, fmt.Errorf("unable to unmarshal request: %v", err)
+ }
+
+ if request.System == "" {
+ request.System = content
+ } else {
+ request.System = content + "\n" + request.System
+ }
+
+ return json.Marshal(request)
+}
+
+func (c *claudeProvider) GetApiName(path string) ApiName {
+ if strings.Contains(path, claudeChatCompletionPath) {
+ return ApiNameChatCompletion
+ }
+ return ""
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
index 35f6f2dc78..2f6108b0df 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
@@ -2,12 +2,11 @@ package provider
import (
"errors"
- "fmt"
+ "net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -47,13 +46,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
- _ = util.OverwriteRequestHost(cloudflareDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + c.config.GetRandomToken())
-
- _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
-
+ c.config.handleRequestHeaders(c, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -61,49 +54,13 @@ func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
+ return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
+}
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- model := request.Model
- if model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
- }
- ctx.SetContext(ctxKeyOriginalRequestModel, model)
- mappedModel := getMappedModel(model, c.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- request.Model = mappedModel
- ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
-
- streaming := request.Stream
- if streaming {
- _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
- }
-
- if c.contextCache == nil {
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.cloudflare.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- return types.ActionContinue, nil
- }
- err := c.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.cloudflare.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.cloudflare.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
+ util.OverwriteRequestHostHeader(headers, cloudflareDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
+ headers.Del("Accept-Encoding")
+ headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
index 7ffe1708af..72dbaf280b 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
@@ -3,17 +3,16 @@ package provider
import (
"encoding/json"
"errors"
- "fmt"
-
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+ "net/http"
+ "strings"
)
const (
- cohereDomain = "api.cohere.com"
- chatCompletionPath = "/v1/chat"
+ cohereDomain = "api.cohere.com"
+ cohereChatCompletionPath = "/v1/chat"
)
type cohereProviderInitializer struct{}
@@ -27,12 +26,14 @@ func (m *cohereProviderInitializer) ValidateConfig(config ProviderConfig) error
func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &cohereProvider{
- config: config,
+ config: config,
+ contextCache: createContextCache(&config),
}, nil
}
type cohereProvider struct {
- config ProviderConfig
+ config ProviderConfig
+ contextCache *contextCache
}
type cohereTextGenRequest struct {
@@ -57,10 +58,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestHost(cohereDomain)
- _ = util.OverwriteRequestPath(chatCompletionPath)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -68,30 +66,7 @@ func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- if m.config.protocol == protocolOriginal {
- request := &cohereTextGenRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
- return m.handleRequestBody(log, request)
- }
- origin := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, origin); err != nil {
- return types.ActionContinue, err
- }
- request := m.buildCohereRequest(origin)
- return m.handleRequestBody(log, request)
-}
-
-func (m *cohereProvider) handleRequestBody(log wrapper.Log, request interface{}) (types.Action, error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- err := replaceJsonRequestBody(request, log)
- if err != nil {
- _ = util.SendResponse(500, "ai-proxy.cohere.proxy_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- return types.ActionContinue, err
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohereTextGenRequest {
@@ -112,3 +87,27 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe
PresencePenalty: origin.PresencePenalty,
}
}
+
+func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, cohereDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
+}
+
+func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
+ request := &chatCompletionRequest{}
+ if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
+ return nil, err
+ }
+
+ cohereRequest := m.buildCohereRequest(request)
+ return json.Marshal(cohereRequest)
+}
+
+func (m *cohereProvider) GetApiName(path string) ApiName {
+ if strings.Contains(path, cohereChatCompletionPath) {
+ return ApiNameChatCompletion
+ }
+ return ""
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/context.go b/plugins/wasm-go/extensions/ai-proxy/provider/context.go
index 2026a9818a..d9fe2e26c4 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/context.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/context.go
@@ -1,12 +1,15 @@
package provider
import (
+ "encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
+ "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+ "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/gjson"
)
@@ -57,6 +60,10 @@ type contextCache struct {
content string
}
+type ContextInserter interface {
+ insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error)
+}
+
func (c *contextCache) GetContent(callback func(string, error), log wrapper.Log) error {
if callback == nil {
return errors.New("callback is nil")
@@ -98,3 +105,79 @@ func createContextCache(providerConfig *ProviderConfig) *contextCache {
timeout: providerConfig.timeout,
}
}
+
+func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte, log wrapper.Log) error {
+ // get context will overwrite the original request host and path
+ // save the original request host and path in case they are needed for apiToken health check
+ ctx.SetContext(ctxRequestHost, wrapper.GetRequestHost())
+ ctx.SetContext(ctxRequestPath, wrapper.GetRequestPath())
+
+ if c.loaded {
+ log.Debugf("context file loaded from cache")
+ insertContext(provider, c.content, nil, body, log)
+ return nil
+ }
+
+ log.Infof("loading context file from %s", c.fileUrl.String())
+ return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
+ if statusCode != http.StatusOK {
+ insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil, log)
+ return
+ }
+ c.content = string(responseBody)
+ c.loaded = true
+ log.Debugf("content: %s", c.content)
+ insertContext(provider, c.content, nil, body, log)
+ }, c.timeout)
+}
+
+func insertContext(provider Provider, content string, err error, body []byte, log wrapper.Log) {
+ defer func() {
+ _ = proxywasm.ResumeHttpRequest()
+ }()
+
+ typ := provider.GetProviderType()
+ if err != nil {
+ log.Errorf("failed to load context file: %v", err)
+ _ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.load_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
+ }
+
+ if inserter, ok := provider.(ContextInserter); ok {
+ body, err = inserter.insertHttpContextMessage(body, content, false)
+ } else {
+ body, err = defaultInsertHttpContextMessage(body, content)
+ }
+
+ if err != nil {
+ _ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to insert context message: %v", err))
+ }
+ if err := replaceHttpJsonRequestBody(body, log); err != nil {
+ _ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
+ }
+}
+
+func defaultInsertHttpContextMessage(body []byte, content string) ([]byte, error) {
+ request := &chatCompletionRequest{}
+ if err := json.Unmarshal(body, request); err != nil {
+ return nil, fmt.Errorf("unable to unmarshal request: %v", err)
+ }
+
+ fileMessage := chatMessage{
+ Role: roleSystem,
+ Content: content,
+ }
+ var firstNonSystemMessageIndex int
+ for i, message := range request.Messages {
+ if message.Role != roleSystem {
+ firstNonSystemMessageIndex = i
+ break
+ }
+ }
+ if firstNonSystemMessageIndex == 0 {
+ request.Messages = append([]chatMessage{fileMessage}, request.Messages...)
+ } else {
+ request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...)
+ }
+
+ return json.Marshal(request)
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
index 924746c8c9..bafe6b3dde 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
@@ -4,6 +4,8 @@ import (
"encoding/json"
"errors"
"fmt"
+ "net/http"
+ "strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -78,49 +80,38 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(deeplChatCompletionPath)
- _ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
- _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
+ d.config.handleRequestHeaders(d, ctx, apiName, log)
return types.HeaderStopIteration, nil
}
+func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
+ util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
+ headers.Del("Accept-Encoding")
+}
+
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- if d.config.protocol == protocolOriginal {
- request := &deeplRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
- if err := d.overwriteRequestHost(request.Model); err != nil {
- return types.ActionContinue, err
- }
- ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
- return types.ActionContinue, replaceJsonRequestBody(request, log)
- } else {
- originRequest := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, originRequest); err != nil {
- return types.ActionContinue, err
- }
- if err := d.overwriteRequestHost(originRequest.Model); err != nil {
- return types.ActionContinue, err
- }
- ctx.SetContext(ctxKeyFinalRequestModel, originRequest.Model)
- deeplRequest := &deeplRequest{
- Text: make([]string, 0),
- TargetLang: d.config.targetLang,
- }
- for _, msg := range originRequest.Messages {
- if msg.Role == roleSystem {
- deeplRequest.Context = msg.StringContent()
- } else {
- deeplRequest.Text = append(deeplRequest.Text, msg.StringContent())
- }
- }
- return types.ActionContinue, replaceJsonRequestBody(deeplRequest, log)
+ return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
+}
+
+func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
+ request := &chatCompletionRequest{}
+ if err := decodeChatCompletionRequest(body, request); err != nil {
+ return nil, err
}
+ ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
+
+ err := d.overwriteRequestHost(headers, request.Model)
+ if err != nil {
+ return nil, err
+ }
+
+ baiduRequest := d.deeplTextGenRequest(request)
+ return json.Marshal(baiduRequest)
}
func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -164,13 +155,35 @@ func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplRespo
}
}
-func (d *deeplProvider) overwriteRequestHost(model string) error {
+func (d *deeplProvider) overwriteRequestHost(headers http.Header, model string) error {
if model == "Pro" {
- _ = util.OverwriteRequestHost(deeplHostPro)
+ util.OverwriteRequestHostHeader(headers, deeplHostPro)
} else if model == "Free" {
- _ = util.OverwriteRequestHost(deeplHostFree)
+ util.OverwriteRequestHostHeader(headers, deeplHostFree)
} else {
return errors.New(`deepl model should be "Free" or "Pro"`)
}
return nil
}
+
+func (d *deeplProvider) deeplTextGenRequest(request *chatCompletionRequest) *deeplRequest {
+ deeplRequest := &deeplRequest{
+ Text: make([]string, 0),
+ TargetLang: d.config.targetLang,
+ }
+ for _, msg := range request.Messages {
+ if msg.Role == roleSystem {
+ deeplRequest.Context = msg.StringContent()
+ } else {
+ deeplRequest.Text = append(deeplRequest.Text, msg.StringContent())
+ }
+ }
+ return deeplRequest
+}
+
+func (d *deeplProvider) GetApiName(path string) ApiName {
+ if strings.Contains(path, deeplChatCompletionPath) {
+ return ApiNameChatCompletion
+ }
+ return ""
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
index 8cb71462d2..9cad3928f5 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
@@ -2,12 +2,10 @@ package provider
import (
"errors"
- "fmt"
-
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+ "net/http"
)
// deepseekProvider is the provider for deepseek Ai service.
@@ -47,10 +45,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(deepseekChatCompletionPath)
- _ = util.OverwriteRequestHost(deepseekDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -58,28 +53,12 @@ func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- if m.contextCache == nil {
- return types.ActionContinue, nil
- }
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.deepseek.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.deepseek.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
+
+func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, deepseekChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, deepseekDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
index 0ca349a773..651b983206 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
@@ -2,12 +2,11 @@ package provider
import (
"errors"
- "fmt"
-
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+ "net/http"
+ "strings"
)
const (
@@ -41,17 +40,10 @@ func (m *doubaoProvider) GetProviderType() string {
}
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
- _ = util.OverwriteRequestHost(doubaoDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
- if m.config.protocol == protocolOriginal {
- ctx.DontReadRequestBody()
- return types.ActionContinue, nil
- }
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(doubaoChatCompletionPath)
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -59,44 +51,19 @@ func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- model := request.Model
- if model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
- }
- mappedModel := getMappedModel(model, m.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- request.Model = mappedModel
- if m.contextCache != nil {
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.doubao.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.doubao.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- } else {
- return types.ActionContinue, err
- }
- } else {
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.doubao.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- return types.ActionContinue, err
- }
- _ = proxywasm.ResumeHttpRequest()
- return types.ActionPause, nil
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
+
+func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, doubaoChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, doubaoDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
+}
+
+func (m *doubaoProvider) GetApiName(path string) ApiName {
+ if strings.Contains(path, doubaoChatCompletionPath) {
+ return ApiNameChatCompletion
}
+ return ""
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go
new file mode 100644
index 0000000000..32e92a4db4
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go
@@ -0,0 +1,594 @@
+package provider
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
+ "github.com/google/uuid"
+ "math/rand"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+ "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
+ "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+ "github.com/tidwall/gjson"
+)
+
+type failover struct {
+ // @Title zh-CN 是否启用 apiToken 的 failover 机制
+ enabled bool `required:"true" yaml:"enabled" json:"enabled"`
+ // @Title zh-CN 触发 failover 连续请求失败的阈值
+ failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
+ // @Title zh-CN 健康检测的成功阈值
+ successThreshold int64 `required:"false" yaml:"successThreshold" json:"successThreshold"`
+ // @Title zh-CN 健康检测的间隔时间,单位毫秒
+ healthCheckInterval int64 `required:"false" yaml:"healthCheckInterval" json:"healthCheckInterval"`
+ // @Title zh-CN 健康检测的超时时间,单位毫秒
+ healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
+ // @Title zh-CN 健康检测使用的模型
+ healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
+ // @Title zh-CN 本次请求使用的 apiToken
+ ctxApiTokenInUse string
+ // @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数
+ ctxApiTokenRequestFailureCount string
+ // @Title zh-CN 记录 apiToken 健康检测成功的次数,key 为 apiToken,value 为成功次数
+ ctxApiTokenRequestSuccessCount string
+ // @Title zh-CN 记录所有可用的 apiToken 列表
+ ctxApiTokens string
+ // @Title zh-CN 记录所有不可用的 apiToken 列表
+ ctxUnavailableApiTokens string
+ // @Title zh-CN 记录请求的 cluster, host 和 path,用于在健康检测时构建请求
+ ctxHealthCheckEndpoint string
+ // @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测
+ ctxVmLease string
+}
+
+type Lease struct {
+ VMID string `json:"vmID"`
+ Timestamp int64 `json:"timestamp"`
+}
+
+type HealthCheckEndpoint struct {
+ Host string `json:"host"`
+ Path string `json:"path"`
+ Cluster string `json:"cluster"`
+}
+
+const (
+ casMaxRetries = 10
+ addApiTokenOperation = "addApiToken"
+ removeApiTokenOperation = "removeApiToken"
+ addApiTokenRequestCountOperation = "addApiTokenRequestCount"
+ resetApiTokenRequestCountOperation = "resetApiTokenRequestCount"
+ ctxRequestHost = "requestHost"
+ ctxRequestPath = "requestPath"
+)
+
+var (
+ healthCheckClient wrapper.HttpClient
+)
+
+func (f *failover) FromJson(json gjson.Result) {
+ f.enabled = json.Get("enabled").Bool()
+ f.failureThreshold = json.Get("failureThreshold").Int()
+ if f.failureThreshold == 0 {
+ f.failureThreshold = 3
+ }
+ f.successThreshold = json.Get("successThreshold").Int()
+ if f.successThreshold == 0 {
+ f.successThreshold = 1
+ }
+ f.healthCheckInterval = json.Get("healthCheckInterval").Int()
+ if f.healthCheckInterval == 0 {
+ f.healthCheckInterval = 5000
+ }
+ f.healthCheckTimeout = json.Get("healthCheckTimeout").Int()
+ if f.healthCheckTimeout == 0 {
+ f.healthCheckTimeout = 5000
+ }
+ f.healthCheckModel = json.Get("healthCheckModel").String()
+}
+
+func (f *failover) Validate() error {
+ if f.healthCheckModel == "" {
+ return errors.New("missing healthCheckModel in failover config")
+ }
+ return nil
+}
+
+func (c *ProviderConfig) initVariable() {
+ // Set provider name as prefix to differentiate shared data
+ provider := c.GetType()
+ c.failover.ctxApiTokenInUse = provider + "-apiTokenInUse"
+ c.failover.ctxApiTokenRequestFailureCount = provider + "-apiTokenRequestFailureCount"
+ c.failover.ctxApiTokenRequestSuccessCount = provider + "-apiTokenRequestSuccessCount"
+ c.failover.ctxApiTokens = provider + "-apiTokens"
+ c.failover.ctxUnavailableApiTokens = provider + "-unavailableApiTokens"
+ c.failover.ctxHealthCheckEndpoint = provider + "-requestHostAndPath"
+ c.failover.ctxVmLease = provider + "-vmLease"
+}
+
+func parseConfig(json gjson.Result, config *any, log wrapper.Log) error {
+ return nil
+}
+
+func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Provider) error {
+ c.initVariable()
+ // Reset shared data in case plugin configuration is updated
+ log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
+ c.resetSharedData()
+
+ if c.isFailoverEnabled() {
+ log.Debugf("ai-proxy plugin failover is enabled")
+
+ vmID := generateVMID()
+ err := c.initApiTokens()
+
+ if err != nil {
+ return fmt.Errorf("failed to init apiTokens: %v", err)
+ }
+
+ wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() {
+ // Only the Wasm VM that successfully acquires the lease will perform health check
+ if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID, log) {
+ log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType())
+ unavailableTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
+ if err != nil {
+ log.Errorf("Failed to get unavailable tokens: %v", err)
+ return
+ }
+ if len(unavailableTokens) > 0 {
+ for _, apiToken := range unavailableTokens {
+ log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
+ healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody(log)
+ healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
+ Host: healthCheckEndpoint.Host,
+ Cluster: healthCheckEndpoint.Cluster,
+ })
+
+ ctx := createHttpContext()
+ ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
+
+ modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body, log)
+ if err != nil {
+ log.Errorf("Failed to transform request headers and body: %v", err)
+ }
+
+ // The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
+ err = healthCheckClient.Post(healthCheckEndpoint.Path, modifiedHeaders, modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
+ if statusCode == 200 {
+ c.handleAvailableApiToken(apiToken, log)
+ }
+ }, uint32(c.failover.healthCheckTimeout))
+ if err != nil {
+ log.Errorf("Failed to perform health check request: %v", err)
+ }
+ }
+ }
+ }
+ })
+ }
+ return nil
+}
+
+func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte, log wrapper.Log) ([][2]string, []byte, error) {
+ originalHeaders := util.SliceToHeader(headers)
+ if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
+ handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, originalHeaders, log)
+ }
+
+ var err error
+ if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
+ body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
+ } else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
+ headers := util.GetOriginalHttpHeaders()
+ body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
+ util.ReplaceOriginalHttpHeaders(headers)
+ } else {
+ body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
+ }
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to transform request body: %v", err)
+ }
+
+ modifiedHeaders := util.HeaderToSlice(originalHeaders)
+ return modifiedHeaders, body, nil
+}
+
+func createHttpContext() *wrapper.CommonHttpCtx[any] {
+ setParseConfig := wrapper.ParseConfigBy[any](parseConfig)
+ vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig)
+ pluginCtx := vmCtx.NewPluginContext(rand.Uint32())
+ ctx := pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any])
+ return ctx
+}
+
+func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthCheckEndpoint, [][2]string, []byte) {
+ data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint)
+ if err != nil {
+ log.Errorf("Failed to get request host and path: %v", err)
+ }
+ var healthCheckEndpoint HealthCheckEndpoint
+ err = json.Unmarshal(data, &healthCheckEndpoint)
+ if err != nil {
+ log.Errorf("Failed to unmarshal request host and path: %v", err)
+ }
+
+ headers := [][2]string{
+ {"content-type", "application/json"},
+ }
+ body := []byte(fmt.Sprintf(`{
+ "model": "%s",
+ "messages": [
+ {
+ "role": "user",
+ "content": "who are you?"
+ }
+ ]
+ }`, c.failover.healthCheckModel))
+ return healthCheckEndpoint, headers, body
+}
+
+func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool {
+ now := time.Now().Unix()
+
+ data, cas, err := proxywasm.GetSharedData(c.failover.ctxVmLease)
+ if err != nil {
+ if errors.Is(err, types.ErrorStatusNotFound) {
+ return c.setLease(vmID, now, cas, log)
+ } else {
+ log.Errorf("Failed to get lease: %v", err)
+ return false
+ }
+ }
+ if data == nil {
+ return c.setLease(vmID, now, cas, log)
+ }
+
+ var lease Lease
+ err = json.Unmarshal(data, &lease)
+ if err != nil {
+ log.Errorf("Failed to unmarshal lease data: %v", err)
+ return false
+ }
+ // If vmID is itself, try to renew the lease directly
+ // If the lease is expired (60s), try to acquire the lease
+ if lease.VMID == vmID || now-lease.Timestamp > 60 {
+ lease.VMID = vmID
+ lease.Timestamp = now
+ return c.setLease(vmID, now, cas, log)
+ }
+
+ return false
+}
+
+func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool {
+ lease := Lease{
+ VMID: vmID,
+ Timestamp: timestamp,
+ }
+ leaseByte, err := json.Marshal(lease)
+ if err != nil {
+ log.Errorf("Failed to marshal lease data: %v", err)
+ return false
+ }
+
+ if err := proxywasm.SetSharedData(c.failover.ctxVmLease, leaseByte, cas); err != nil {
+ log.Errorf("Failed to set or renew lease: %v", err)
+ return false
+ }
+ return true
+}
+
+func generateVMID() string {
+ return uuid.New().String()
+}
+
+// When number of request successes exceeds the threshold during health check,
+// add the apiToken back to the available list and remove it from the unavailable list
+func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Log) {
+ successApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount)
+ if err != nil {
+ log.Errorf("Failed to get successApiTokenRequestCount: %v", err)
+ return
+ }
+
+ successCount := successApiTokenRequestCount[apiToken] + 1
+ if successCount >= c.failover.successThreshold {
+ log.Infof("apiToken %s is available now, add it back to the apiTokens list", apiToken)
+ removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
+ addApiToken(c.failover.ctxApiTokens, apiToken, log)
+ resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
+ } else {
+ log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check...", apiToken, successCount)
+ addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
+ }
+}
+
+// When number of request failures exceeds the threshold,
+// remove the apiToken from the available list and add it to the unavailable list
+func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string, log wrapper.Log) {
+ failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
+ if err != nil {
+ log.Errorf("Failed to get failureApiTokenRequestCount: %v", err)
+ return
+ }
+
+ availableTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
+ if err != nil {
+ log.Errorf("Failed to get available apiToken: %v", err)
+ return
+ }
+ // unavailable apiToken has been removed from the available list
+ if !containsElement(availableTokens, apiToken) {
+ return
+ }
+
+ failureCount := failureApiTokenRequestCount[apiToken] + 1
+ if failureCount >= c.failover.failureThreshold {
+ log.Infof("apiToken %s is unavailable now, remove it from apiTokens list", apiToken)
+ removeApiToken(c.failover.ctxApiTokens, apiToken, log)
+ addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
+ resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
+ // Set the request host and path to shared data in case they are needed in apiToken health check
+ c.setHealthCheckEndpoint(ctx, log)
+ } else {
+ log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount)
+ addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
+ }
+}
+
+func addApiToken(key, apiToken string, log wrapper.Log) {
+ modifyApiToken(key, apiToken, addApiTokenOperation, log)
+}
+
+func removeApiToken(key, apiToken string, log wrapper.Log) {
+ modifyApiToken(key, apiToken, removeApiTokenOperation, log)
+}
+
+func modifyApiToken(key, apiToken, op string, log wrapper.Log) {
+ for attempt := 1; attempt <= casMaxRetries; attempt++ {
+ apiTokens, cas, err := getApiTokens(key)
+ if err != nil {
+ log.Errorf("Failed to get %s: %v", key, err)
+ continue
+ }
+
+ exists := containsElement(apiTokens, apiToken)
+ if op == addApiTokenOperation && exists {
+ log.Debugf("%s already exists in %s", apiToken, key)
+ return
+ } else if op == removeApiTokenOperation && !exists {
+ log.Debugf("%s does not exist in %s", apiToken, key)
+ return
+ }
+
+ if op == addApiTokenOperation {
+ apiTokens = append(apiTokens, apiToken)
+ } else {
+ apiTokens = removeElement(apiTokens, apiToken)
+ }
+
+ if err := setApiTokens(key, apiTokens, cas); err == nil {
+ log.Debugf("Successfully updated %s in %s", apiToken, key)
+ return
+ } else if !errors.Is(err, types.ErrorStatusCasMismatch) {
+ log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
+ return
+ }
+
+ log.Errorf("CAS mismatch when setting %s, retrying...", key)
+ }
+}
+
+func getApiTokens(key string) ([]string, uint32, error) {
+ data, cas, err := proxywasm.GetSharedData(key)
+ if err != nil {
+ if errors.Is(err, types.ErrorStatusNotFound) {
+ return []string{}, cas, nil
+ }
+ return nil, 0, err
+ }
+ if data == nil {
+ return []string{}, cas, nil
+ }
+
+ var apiTokens []string
+ if err = json.Unmarshal(data, &apiTokens); err != nil {
+ return nil, 0, fmt.Errorf("failed to unmarshal tokens: %v", err)
+ }
+
+ return apiTokens, cas, nil
+}
+
+func setApiTokens(key string, apiTokens []string, cas uint32) error {
+ data, err := json.Marshal(apiTokens)
+ if err != nil {
+ return fmt.Errorf("failed to marshal tokens: %v", err)
+ }
+ return proxywasm.SetSharedData(key, data, cas)
+}
+
+func removeElement(slice []string, s string) []string {
+ for i := 0; i < len(slice); i++ {
+ if slice[i] == s {
+ slice = append(slice[:i], slice[i+1:]...)
+ i--
+ }
+ }
+ return slice
+}
+
+func containsElement(slice []string, s string) bool {
+ for _, item := range slice {
+ if item == s {
+ return true
+ }
+ }
+ return false
+}
+
+func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) {
+ data, cas, err := proxywasm.GetSharedData(key)
+ if err != nil {
+ if errors.Is(err, types.ErrorStatusNotFound) {
+ return make(map[string]int64), cas, nil
+ }
+ return nil, 0, err
+ }
+
+ if data == nil {
+ return make(map[string]int64), cas, nil
+ }
+
+ var apiTokens map[string]int64
+ err = json.Unmarshal(data, &apiTokens)
+ if err != nil {
+ return nil, 0, err
+ }
+ return apiTokens, cas, nil
+}
+
+func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
+ modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log)
+}
+
+func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
+ modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log)
+}
+
+func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) {
+ if c.isFailoverEnabled() {
+ failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
+ if err != nil {
+ log.Errorf("failed to get failureApiTokenRequestCount: %v", err)
+ }
+ if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok {
+ log.Infof("reset apiToken %s request failure count", apiTokenInUse)
+ resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log)
+ }
+ }
+}
+
+func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) {
+ for attempt := 1; attempt <= casMaxRetries; attempt++ {
+ apiTokenRequestCount, cas, err := getApiTokenRequestCount(key)
+ if err != nil {
+ log.Errorf("Failed to get %s: %v", key, err)
+ continue
+ }
+
+ if op == resetApiTokenRequestCountOperation {
+ delete(apiTokenRequestCount, apiToken)
+ } else {
+ apiTokenRequestCount[apiToken]++
+ }
+
+ apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount)
+ if err != nil {
+ log.Errorf("failed to marshal apiTokenRequestCount: %v", err)
+ }
+
+ if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil {
+ log.Debugf("Successfully updated the count of %s in %s", apiToken, key)
+ return
+ } else if !errors.Is(err, types.ErrorStatusCasMismatch) {
+ log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
+ return
+ }
+
+ log.Errorf("CAS mismatch when setting %s, retrying...", key)
+ }
+}
+
+func (c *ProviderConfig) initApiTokens() error {
+ return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
+}
+
+func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string {
+ apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
+ unavailableApiTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
+ log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens)
+
+ if err != nil {
+ return ""
+ }
+ count := len(apiTokens)
+ switch count {
+ case 0:
+ return ""
+ case 1:
+ return apiTokens[0]
+ default:
+ return apiTokens[rand.Intn(count)]
+ }
+}
+
+func (c *ProviderConfig) isFailoverEnabled() bool {
+ return c.failover.enabled
+}
+
+func (c *ProviderConfig) resetSharedData() {
+ _ = proxywasm.SetSharedData(c.failover.ctxVmLease, nil, 0)
+ _ = proxywasm.SetSharedData(c.failover.ctxApiTokens, nil, 0)
+ _ = proxywasm.SetSharedData(c.failover.ctxUnavailableApiTokens, nil, 0)
+ _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestSuccessCount, nil, 0)
+ _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
+}
+
+func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) {
+ if c.isFailoverEnabled() {
+ c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
+ }
+}
+
+func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
+ return ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
+}
+
+func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
+ var apiToken string
+ if c.isFailoverEnabled() {
+ // if enable apiToken failover, only use available apiToken
+ apiToken = c.GetGlobalRandomToken(log)
+ } else {
+ apiToken = c.GetRandomToken()
+ }
+ log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken)
+ ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
+}
+
+func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext, log wrapper.Log) {
+ cluster, err := proxywasm.GetProperty([]string{"cluster_name"})
+ if err != nil {
+ log.Errorf("Failed to get cluster_name: %v", err)
+ }
+
+ host := wrapper.GetRequestHost()
+ if host == "" {
+ host = ctx.GetContext(ctxRequestHost).(string)
+ }
+ path := wrapper.GetRequestPath()
+ if path == "" {
+ path = ctx.GetContext(ctxRequestPath).(string)
+ }
+
+ healthCheckEndpoint := HealthCheckEndpoint{
+ Host: host,
+ Path: path,
+ Cluster: string(cluster),
+ }
+
+ healthCheckEndpointByte, err := json.Marshal(healthCheckEndpoint)
+ if err != nil {
+ log.Errorf("Failed to marshal request host and path: %v", err)
+
+ }
+ err = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, healthCheckEndpointByte, 0)
+ if err != nil {
+ log.Errorf("Failed to set request host and path: %v", err)
+ }
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
index 0d418c16a5..a4c1ef2cd9 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "net/http"
"strings"
"time"
@@ -17,8 +18,11 @@ import (
// geminiProvider is the provider for google gemini/gemini flash service.
const (
- geminiApiKeyHeader = "x-goog-api-key"
- geminiDomain = "generativelanguage.googleapis.com"
+ geminiApiKeyHeader = "x-goog-api-key"
+ geminiDomain = "generativelanguage.googleapis.com"
+ geminiChatCompletionPath = "generateContent"
+ geminiChatCompletionStreamPath = "streamGenerateContent?alt=sse"
+ geminiEmbeddingPath = "batchEmbedContents"
)
type geminiProviderInitializer struct {
@@ -51,157 +55,56 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
-
- _ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, g.config.GetRandomToken())
- _ = util.OverwriteRequestHost(geminiDomain)
-
- _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
-
+ g.config.handleRequestHeaders(g, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
-func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
- if apiName == ApiNameChatCompletion {
- return g.onChatCompletionRequestBody(ctx, body, log)
- } else if apiName == ApiNameEmbeddings {
- return g.onEmbeddingsRequestBody(ctx, body, log)
- }
- return types.ActionContinue, errUnsupportedApiName
+func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestHostHeader(headers, geminiDomain)
+ headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
+ headers.Del("Accept-Encoding")
+ headers.Del("Content-Length")
}
-func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
- // 使用gemini接口协议
- if g.config.protocol == protocolOriginal {
- request := &geminiChatRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
- if request.Model == "" {
- return types.ActionContinue, errors.New("request model is empty")
- }
- // 根据模型重写requestPath
- path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
- _ = util.OverwriteRequestPath(path)
-
- // 移除多余的model和stream字段
- request = &geminiChatRequest{
- Contents: request.Contents,
- SafetySettings: request.SafetySettings,
- GenerationConfig: request.GenerationConfig,
- Tools: request.Tools,
- }
- if g.config.context == nil {
- return types.ActionContinue, replaceJsonRequestBody(request, log)
- }
-
- err := g.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
-
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- g.setSystemContent(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
- }
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
+func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
+ if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
+ return types.ActionContinue, errUnsupportedApiName
}
+ return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
+}
- // 映射模型重写requestPath
- model := request.Model
- if model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
- }
- ctx.SetContext(ctxKeyOriginalRequestModel, model)
- mappedModel := getMappedModel(model, g.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
+func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
+ if apiName == ApiNameChatCompletion {
+ return g.onChatCompletionRequestBody(ctx, body, headers, log)
+ } else {
+ return g.onEmbeddingsRequestBody(ctx, body, headers, log)
}
- request.Model = mappedModel
- ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
- path := g.getRequestPath(ApiNameChatCompletion, mappedModel, request.Stream)
- _ = util.OverwriteRequestPath(path)
+}
- if g.config.context == nil {
- geminiRequest := g.buildGeminiChatRequest(request)
- return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log)
+func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
+ request := &chatCompletionRequest{}
+ err := g.config.parseRequestAndMapModel(ctx, request, body, log)
+ if err != nil {
+ return nil, err
}
+ path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
+ util.OverwriteRequestPathHeader(headers, path)
- err := g.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- geminiRequest := g.buildGeminiChatRequest(request)
- if err := replaceJsonRequestBody(geminiRequest, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ geminiRequest := g.buildGeminiChatRequest(request)
+ return json.Marshal(geminiRequest)
}
-func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
- // 使用gemini接口协议
- if g.config.protocol == protocolOriginal {
- request := &geminiBatchEmbeddingRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
- if request.Model == "" {
- return types.ActionContinue, errors.New("request model is empty")
- }
- // 根据模型重写requestPath
- path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
- _ = util.OverwriteRequestPath(path)
-
- // 移除多余的model字段
- request = &geminiBatchEmbeddingRequest{
- Requests: request.Requests,
- }
- return types.ActionContinue, replaceJsonRequestBody(request, log)
- }
+func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &embeddingsRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
+ if err := g.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
+ return nil, err
}
-
- // 映射模型重写requestPath
- model := request.Model
- if model == "" {
- return types.ActionContinue, errors.New("missing model in embeddings request")
- }
- ctx.SetContext(ctxKeyOriginalRequestModel, model)
- mappedModel := getMappedModel(model, g.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- request.Model = mappedModel
- ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
- path := g.getRequestPath(ApiNameEmbeddings, mappedModel, false)
- _ = util.OverwriteRequestPath(path)
+ path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
+ util.OverwriteRequestPathHeader(headers, path)
geminiRequest := g.buildBatchEmbeddingRequest(request)
- return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log)
+ return json.Marshal(geminiRequest)
}
func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -285,11 +188,11 @@ func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
action := ""
if apiName == ApiNameEmbeddings {
- action = "batchEmbedContents"
+ action = geminiEmbeddingPath
} else if stream {
- action = "streamGenerateContent?alt=sse"
+ action = geminiChatCompletionStreamPath
} else {
- action = "generateContent"
+ action = geminiChatCompletionPath
}
return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action)
}
@@ -605,3 +508,13 @@ func (g *geminiProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, gemini
func (g *geminiProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
+
+func (g *geminiProvider) GetApiName(path string) ApiName {
+ if strings.Contains(path, geminiChatCompletionPath) || strings.Contains(path, geminiChatCompletionStreamPath) {
+ return ApiNameChatCompletion
+ }
+ if strings.Contains(path, geminiEmbeddingPath) {
+ return ApiNameEmbeddings
+ }
+ return ""
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go
index 5ee51b2742..0a2b0c84de 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go
@@ -1,14 +1,12 @@
package provider
import (
- "encoding/json"
"errors"
- "fmt"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
-
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+ "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+ "net/http"
+ "strings"
)
// githubProvider is the provider for GitHub OpenAI service.
@@ -48,16 +46,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestHost(githubDomain)
- if apiName == ApiNameChatCompletion {
- _ = util.OverwriteRequestPath(githubCompletionPath)
- }
- if apiName == ApiNameEmbeddings {
- _ = util.OverwriteRequestPath(githubEmbeddingPath)
- }
- _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
- _ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
@@ -66,47 +55,28 @@ func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
- if apiName == ApiNameChatCompletion {
- return m.onChatCompletionRequestBody(ctx, body, log)
- }
- if apiName == ApiNameEmbeddings {
- return m.onEmbeddingsRequestBody(ctx, body, log)
- }
- return types.ActionContinue, errUnsupportedApiName
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
-func (m *githubProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- if request.Model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
+func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestHostHeader(headers, githubDomain)
+ if apiName == ApiNameChatCompletion {
+ util.OverwriteRequestPathHeader(headers, githubCompletionPath)
}
- // 映射模型
- mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
+ if apiName == ApiNameEmbeddings {
+ util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
}
- ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
- request.Model = mappedModel
- return types.ActionContinue, replaceJsonRequestBody(request, log)
+ util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
+ headers.Del("Accept-Encoding")
+ headers.Del("Content-Length")
}
-func (m *githubProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
- request := &embeddingsRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
- if request.Model == "" {
- return types.ActionContinue, errors.New("missing model in embeddings request")
+func (m *githubProvider) GetApiName(path string) ApiName {
+ if strings.Contains(path, githubCompletionPath) {
+ return ApiNameChatCompletion
}
- // 映射模型
- mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
+ if strings.Contains(path, githubEmbeddingPath) {
+ return ApiNameEmbeddings
}
- ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
- request.Model = mappedModel
- return types.ActionContinue, replaceJsonRequestBody(request, log)
+ return ""
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go
index 644e450ee9..dfbd971261 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go
@@ -2,11 +2,11 @@ package provider
import (
"errors"
- "fmt"
+ "net/http"
+ "strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -18,14 +18,14 @@ const (
type groqProviderInitializer struct{}
-func (m *groqProviderInitializer) ValidateConfig(config ProviderConfig) error {
+func (g *groqProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
-func (m *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
+func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &groqProvider{
config: config,
contextCache: createContextCache(&config),
@@ -37,47 +37,35 @@ type groqProvider struct {
contextCache *contextCache
}
-func (m *groqProvider) GetProviderType() string {
+func (g *groqProvider) GetProviderType() string {
return providerTypeGroq
}
-func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(groqChatCompletionPath)
- _ = util.OverwriteRequestHost(groqDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ g.config.handleRequestHeaders(g, ctx, apiName, log)
return types.ActionContinue, nil
}
-func (m *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
+func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- if m.contextCache == nil {
- return types.ActionContinue, nil
- }
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.groq.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.groq.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
+ return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
+}
+
+func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, groqChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, groqDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
+}
+
+func (g *groqProvider) GetApiName(path string) ApiName {
+ if strings.Contains(path, groqChatCompletionPath) {
+ return ApiNameChatCompletion
}
- return types.ActionContinue, err
+ return ""
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
index 7640a380b3..99cb135db6 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
@@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "net/http"
"strings"
"time"
@@ -114,26 +115,27 @@ func (m *hunyuanProvider) GetProviderType() string {
}
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
- // log.Debugf("hunyuanProvider.OnRequestHeaders called! hunyunSecretKey/id is: %s/%s", m.config.hunyuanAuthKey, m.config.hunyuanAuthId)
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
+ // Delay the header processing to allow changing streaming mode in OnRequestBody
+ return types.HeaderStopIteration, nil
+}
- _ = util.OverwriteRequestHost(hunyuanDomain)
- _ = util.OverwriteRequestPath(hunyuanRequestPath)
-
- // 添加hunyuan需要的自定义字段
- _ = proxywasm.ReplaceHttpRequestHeader(actionKey, hunyuanChatCompletionTCAction)
- _ = proxywasm.ReplaceHttpRequestHeader(versionKey, versionValue)
+func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestHostHeader(headers, hunyuanDomain)
+ util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
- // 删除一些字段
- _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ // 添加 hunyuan 需要的自定义字段
+ headers.Add(actionKey, hunyuanChatCompletionTCAction)
+ headers.Add(versionKey, versionValue)
- // Delay the header processing to allow changing streaming mode in OnRequestBody
- return types.HeaderStopIteration, nil
+ headers.Del("Accept-Encoding")
+ headers.Del("Content-Length")
}
+// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
@@ -142,7 +144,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
// 为header添加时间戳字段 (因为需要根据body进行签名时依赖时间戳,故于body处理部分创建时间戳)
var timestamp int64 = time.Now().Unix()
_ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp))
- // log.Debugf("#debug nash5# OnRequestBody set timestamp header: ", timestamp)
// 使用混元本身接口的协议
if m.config.protocol == protocolOriginal {
@@ -198,7 +199,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
- // log.Debugf("#debug nash5# OnRequestBody call hunyuan api using openai's api!")
model := request.Model
if model == "" {
@@ -235,18 +235,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
string(body),
)
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
- // log.Debugf("#debug nash5# OnRequestBody done, body is: ", string(body))
-
- // // 打印所有的headers
- // headers, err2 := proxywasm.GetHttpRequestHeaders()
- // if err2 != nil {
- // log.Errorf("failed to get request headers: %v", err2)
- // } else {
- // // 迭代并打印所有请求头
- // for _, header := range headers {
- // log.Infof("#debug nash5# inB Request header - %s: %s", header[0], header[1])
- // }
- // }
return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest, log)
}
@@ -277,6 +265,32 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
return types.ActionContinue, err
}
+// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用
+func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
+ request := &chatCompletionRequest{}
+ err := m.config.parseRequestAndMapModel(ctx, request, body, log)
+ if err != nil {
+ return nil, err
+ }
+
+ hunyuanRequest := m.buildHunyuanTextGenerationRequest(request)
+
+ var timestamp int64 = time.Now().Unix()
+ _ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp))
+ // 根据确定好的payload进行签名:
+ body, _ = json.Marshal(hunyuanRequest)
+ authorizedValueNew := GetTC3Authorizationcode(
+ m.config.hunyuanAuthId,
+ m.config.hunyuanAuthKey,
+ timestamp,
+ hunyuanDomain,
+ hunyuanChatCompletionTCAction,
+ string(body),
+ )
+ util.OverwriteRequestAuthorizationHeader(headers, authorizedValueNew)
+ return json.Marshal(hunyuanRequest)
+}
+
func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
@@ -561,3 +575,7 @@ func GetTC3Authorizationcode(secretId string, secretKey string, timestamp int64,
// fmt.Println(curl)
return authorization
}
+
+func (m *hunyuanProvider) GetApiName(path string) ApiName {
+ return ApiNameChatCompletion
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
index ded72d7b51..00aa0f7254 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -78,14 +79,17 @@ func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestHost(minimaxDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
-
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
+func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestHostHeader(headers, minimaxDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
+}
+
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
@@ -107,51 +111,16 @@ func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
return m.handleRequestBodyByChatCompletionPro(body, log)
} else {
// 使用ChatCompletion v2接口
- return m.handleRequestBodyByChatCompletionV2(body, log)
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
}
+func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
+ return m.handleRequestBodyByChatCompletionV2(body, headers, log)
+}
+
// handleRequestBodyByChatCompletionPro 使用ChatCompletion Pro接口处理请求体
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
- // 使用minimax接口协议
- if m.config.protocol == protocolOriginal {
- request := &minimaxChatCompletionV2Request{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
- if request.Model == "" {
- return types.ActionContinue, errors.New("request model is empty")
- }
- // 根据模型重写requestPath
- if m.config.minimaxGroupId == "" {
- return types.ActionContinue, errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when use %s model ", request.Model))
- }
- _ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
-
- if m.config.context == nil {
- return types.ActionContinue, nil
- }
-
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
-
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- m.setBotSettings(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
- }
-
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
@@ -174,6 +143,9 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
+ // 由于 minimaxChatCompletionV2(格式和 OpenAI 一致)和 minimaxChatCompletionPro(格式和 OpenAI 不一致)中 insertHttpContextMessage 的逻辑不同,无法做到同一个 provider 统一
+ // 因此对于 minimaxChatCompletionPro 需要手动处理 context 消息
+ // minimaxChatCompletionV2 交给默认的 defaultInsertHttpContextMessage 方法插入 context 消息
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content)
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
@@ -186,37 +158,17 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
}
// handleRequestBodyByChatCompletionV2 使用ChatCompletion v2接口处理请求体
-func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, log wrapper.Log) (types.Action, error) {
+func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
+ return nil, err
}
// 映射模型重写requestPath
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
- _ = util.OverwriteRequestPath(minimaxChatCompletionV2Path)
-
- if m.contextCache == nil {
- return types.ActionContinue, replaceJsonRequestBody(request, log)
- }
+ util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path)
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ return body, nil
}
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -474,3 +426,10 @@ func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Re
func (m *minimaxProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
+
+func (m *minimaxProvider) GetApiName(path string) ApiName {
+ if strings.Contains(path, minimaxChatCompletionV2Path) || strings.Contains(path, minimaxChatCompletionProPath) {
+ return ApiNameChatCompletion
+ }
+ return ""
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go
index b217d8019e..3e5323a60c 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go
@@ -2,12 +2,10 @@ package provider
import (
"errors"
- "fmt"
-
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+ "net/http"
)
const (
@@ -43,9 +41,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestHost(mistralDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -53,28 +49,11 @@ func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- if m.contextCache == nil {
- return types.ActionContinue, nil
- }
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.mistral.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.mistral.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
+
+func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestHostHeader(headers, mistralDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
index 6023b4abe8..cb914d8c85 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
@@ -3,13 +3,12 @@ package provider
import (
"errors"
"fmt"
- "net/http"
-
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
+ "net/http"
)
// moonshotProvider is the provider for Moonshot AI service.
@@ -58,33 +57,29 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(moonshotChatCompletionPath)
- _ = util.OverwriteRequestHost(moonshotDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
+func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, moonshotChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, moonshotDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
+}
+
+// moonshot 有自己获取 context 的配置(moonshotFileId),因此无法复用 handleRequestBody 方法
+// moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
+ if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return types.ActionContinue, err
}
- model := request.Model
- if model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
- }
- mappedModel := getMappedModel(model, m.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- request.Model = mappedModel
-
if m.config.moonshotFileId == "" && m.contextCache == nil {
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
index 8895489fbe..5339083819 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
@@ -3,11 +3,10 @@ package provider
import (
"errors"
"fmt"
-
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+ "net/http"
)
// ollamaProvider is the provider for Ollama service.
@@ -53,10 +52,7 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(ollamaChatCompletionPath)
- _ = util.OverwriteRequestHost(m.serviceDomain)
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
-
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -64,51 +60,11 @@ func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
- if m.config.modelMapping == nil && m.contextCache == nil {
- return types.ActionContinue, nil
- }
-
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
-
- model := request.Model
- if model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
- }
- mappedModel := getMappedModel(model, m.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- request.Model = mappedModel
-
- if m.contextCache != nil {
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.ollama.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.ollama.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- } else {
- return types.ActionContinue, err
- }
- } else {
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.ollama.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- return types.ActionContinue, err
- }
- _ = proxywasm.ResumeHttpRequest()
- return types.ActionPause, nil
- }
+func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, ollamaChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, m.serviceDomain)
+ headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go
index 9f34932c1a..60c835cd49 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go
@@ -1,12 +1,13 @@
package provider
import (
+ "encoding/json"
"fmt"
+ "net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -57,27 +58,31 @@ func (m *openaiProvider) GetProviderType() string {
}
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
+ return types.ActionContinue, nil
+}
+
+func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
if m.customPath == "" {
switch apiName {
case ApiNameChatCompletion:
- _ = util.OverwriteRequestPath(defaultOpenaiChatCompletionPath)
+ util.OverwriteRequestPathHeader(headers, defaultOpenaiChatCompletionPath)
case ApiNameEmbeddings:
ctx.DontReadRequestBody()
- _ = util.OverwriteRequestPath(defaultOpenaiEmbeddingsPath)
+ util.OverwriteRequestPathHeader(headers, defaultOpenaiEmbeddingsPath)
}
} else {
- _ = util.OverwriteRequestPath(m.customPath)
+ util.OverwriteRequestPathHeader(headers, m.customPath)
}
if m.customDomain == "" {
- _ = util.OverwriteRequestHost(defaultOpenaiDomain)
+ util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
} else {
- _ = util.OverwriteRequestHost(m.customDomain)
+ util.OverwriteRequestHostHeader(headers, m.customDomain)
}
if len(m.config.apiTokens) > 0 {
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
}
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
- return types.ActionContinue, nil
+ headers.Del("Content-Length")
}
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -85,9 +90,13 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
// We don't need to process the request body for other APIs.
return types.ActionContinue, nil
}
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
+
+func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
+ return nil, err
}
if m.config.responseJsonSchema != nil {
log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema)
@@ -101,27 +110,5 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
request.StreamOptions.IncludeUsage = true
}
}
- if m.contextCache == nil {
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- return types.ActionContinue, nil
- }
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.openai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.openai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ return json.Marshal(request)
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
index c9d59fd035..f74805c912 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
@@ -1,14 +1,17 @@
package provider
import (
+ "encoding/json"
"errors"
"math/rand"
+ "net/http"
"strings"
+ "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
+ "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+ "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
-
- "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
type ApiName string
@@ -110,14 +113,32 @@ type Provider interface {
GetProviderType() string
}
+type ApiNameHandler interface {
+ GetApiName(path string) ApiName
+}
+
type RequestHeadersHandler interface {
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
}
+type TransformRequestHeadersHandler interface {
+ TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
+}
+
type RequestBodyHandler interface {
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
}
+type TransformRequestBodyHandler interface {
+ TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
+}
+
+// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body.
+// Some providers (e.g. baidu, gemini) transform request headers (e.g., path) based on the request body (e.g., model).
+type TransformRequestBodyHeadersHandler interface {
+ TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
+}
+
type ResponseHeadersHandler interface {
OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
}
@@ -143,6 +164,9 @@ type ProviderConfig struct {
// @Title zh-CN 请求超时
// @Description zh-CN 请求AI服务的超时时间,单位为毫秒。默认值为120000,即2分钟
timeout uint32 `required:"false" yaml:"timeout" json:"timeout"`
+ // @Title zh-CN apiToken 故障切换
+ // @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
+ failover *failover `required:"false" yaml:"failover" json:"failover"`
// @Title zh-CN 基于OpenAI协议的自定义后端URL
// @Description zh-CN 仅适用于支持 openai 协议的服务。
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
@@ -289,6 +313,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
}
}
}
+
+ failoverJson := json.Get("failover")
+ c.failover = &failover{
+ enabled: false,
+ }
+ if failoverJson.Exists() {
+ c.failover.FromJson(failoverJson)
+ }
}
func (c *ProviderConfig) Validate() error {
@@ -304,6 +336,12 @@ func (c *ProviderConfig) Validate() error {
}
}
+ if c.failover.enabled {
+ if err := c.failover.Validate(); err != nil {
+ return err
+ }
+ }
+
if c.typ == "" {
return errors.New("missing type in provider config")
}
@@ -355,6 +393,60 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
return initializer.CreateProvider(pc)
}
+func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error {
+ switch req := request.(type) {
+ case *chatCompletionRequest:
+ if err := decodeChatCompletionRequest(body, req); err != nil {
+ return err
+ }
+
+ streaming := req.Stream
+ if streaming {
+ _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
+ }
+
+ return c.setRequestModel(ctx, req, log)
+ case *embeddingsRequest:
+ if err := decodeEmbeddingsRequest(body, req); err != nil {
+ return err
+ }
+ return c.setRequestModel(ctx, req, log)
+ default:
+ return errors.New("unsupported request type")
+ }
+}
+
+func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error {
+ var model *string
+
+ switch req := request.(type) {
+ case *chatCompletionRequest:
+ model = &req.Model
+ case *embeddingsRequest:
+ model = &req.Model
+ default:
+ return errors.New("unsupported request type")
+ }
+
+ return c.mapModel(ctx, model, log)
+}
+
+func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error {
+ if *model == "" {
+ return errors.New("missing model in request")
+ }
+ ctx.SetContext(ctxKeyOriginalRequestModel, *model)
+
+ mappedModel := getMappedModel(*model, c.modelMapping, log)
+ if mappedModel == "" {
+ return errors.New("model becomes empty after applying the configured mapping")
+ }
+
+ *model = mappedModel
+ ctx.SetContext(ctxKeyFinalRequestModel, *model)
+ return nil
+}
+
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
mappedModel := doGetMappedModel(model, modelMapping, log)
if len(mappedModel) != 0 {
@@ -391,3 +483,62 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
+
+func (c *ProviderConfig) handleRequestBody(
+ provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
+) (types.Action, error) {
+ // use original protocol
+ if c.protocol == protocolOriginal {
+ return types.ActionContinue, nil
+ }
+
+ // use openai protocol
+ var err error
+ if handler, ok := provider.(TransformRequestBodyHandler); ok {
+ body, err = handler.TransformRequestBody(ctx, apiName, body, log)
+ } else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
+ headers := util.GetOriginalHttpHeaders()
+ body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
+ util.ReplaceOriginalHttpHeaders(headers)
+ } else {
+ body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
+ }
+
+ if err != nil {
+ return types.ActionContinue, err
+ }
+
+ if apiName == ApiNameChatCompletion {
+ if c.context == nil {
+ return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
+ }
+ err = contextCache.GetContextFromFile(ctx, provider, body, log)
+
+ if err == nil {
+ return types.ActionPause, nil
+ }
+ return types.ActionContinue, err
+ }
+ return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
+}
+
+func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
+ if handler, ok := provider.(TransformRequestHeadersHandler); ok {
+ originalHeaders := util.GetOriginalHttpHeaders()
+ handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log)
+ util.ReplaceOriginalHttpHeaders(originalHeaders)
+ }
+}
+
+func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
+ var request interface{}
+ if apiName == ApiNameChatCompletion {
+ request = &chatCompletionRequest{}
+ } else {
+ request = &embeddingsRequest{}
+ }
+ if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
+ return nil, err
+ }
+ return json.Marshal(request)
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
index f673fa98b2..771feeb51e 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"math"
+ "net/http"
"reflect"
"strings"
"time"
@@ -58,35 +59,50 @@ func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provide
}
type qwenProvider struct {
- config ProviderConfig
-
+ config ProviderConfig
contextCache *contextCache
}
+func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestHostHeader(headers, qwenDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+
+ if m.config.qwenEnableCompatible {
+ util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
+ } else if apiName == ApiNameChatCompletion {
+ util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
+ } else if apiName == ApiNameEmbeddings {
+ util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
+ }
+
+ headers.Del("Accept-Encoding")
+ headers.Del("Content-Length")
+}
+
+func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
+ if apiName == ApiNameChatCompletion {
+ return m.onChatCompletionRequestBody(ctx, body, headers, log)
+ } else {
+ return m.onEmbeddingsRequestBody(ctx, body, log)
+ }
+}
+
func (m *qwenProvider) GetProviderType() string {
return providerTypeQwen
}
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
- _ = util.OverwriteRequestHost(qwenDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
+ if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
+ return types.ActionContinue, errUnsupportedApiName
+ }
+
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
if m.config.protocol == protocolOriginal {
ctx.DontReadRequestBody()
return types.ActionContinue, nil
- } else if m.config.qwenEnableCompatible {
- _ = util.OverwriteRequestPath(qwenCompatiblePath)
- } else if apiName == ApiNameChatCompletion {
- _ = util.OverwriteRequestPath(qwenChatCompletionPath)
- } else if apiName == ApiNameEmbeddings {
- _ = util.OverwriteRequestPath(qwenTextEmbeddingPath)
- } else {
- return types.ActionContinue, errUnsupportedApiName
}
- _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
-
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
@@ -121,65 +137,23 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
}
return types.ActionContinue, nil
}
- if apiName == ApiNameChatCompletion {
- return m.onChatCompletionRequestBody(ctx, body, log)
- }
- if apiName == ApiNameEmbeddings {
- return m.onEmbeddingsRequestBody(ctx, body, log)
- }
- return types.ActionContinue, errUnsupportedApiName
-}
-func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
- if m.config.protocol == protocolOriginal {
- if m.config.context == nil {
- return types.ActionContinue, nil
- }
-
- request := &qwenTextGenRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
-
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
-
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- m.insertContextMessage(request, content, false)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
+ return types.ActionContinue, errUnsupportedApiName
}
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
+func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
+ err := m.config.parseRequestAndMapModel(ctx, request, body, log)
+ if err != nil {
+ return nil, err
}
- model := request.Model
- if model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
- }
- ctx.SetContext(ctxKeyOriginalRequestModel, model)
- mappedModel := getMappedModel(model, m.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- request.Model = mappedModel
- ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
// Use the qwen multimodal model generation API
if strings.HasPrefix(request.Model, qwenVlModelPrefixName) {
- _ = util.OverwriteRequestPath(qwenMultimodalGenerationPath)
+ util.OverwriteRequestPathHeader(headers, qwenMultimodalGenerationPath)
}
streaming := request.Stream
@@ -191,62 +165,20 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
}
- if m.config.context == nil {
- qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
- if streaming {
- ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
- }
- return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
- }
-
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
- if streaming {
- ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
- }
- if err := replaceJsonRequestBody(qwenRequest, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ return m.buildQwenTextGenerationRequest(ctx, request, streaming)
}
-func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
+func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
request := &embeddingsRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
+ if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
+ return nil, err
}
- log.Debugf("=== embeddings request: %v", request)
-
- model := request.Model
- if model == "" {
- return types.ActionContinue, errors.New("missing model in the request")
- }
- ctx.SetContext(ctxKeyOriginalRequestModel, model)
- mappedModel := getMappedModel(model, m.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- request.Model = mappedModel
- ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
-
- if qwenRequest, err := m.buildQwenTextEmbeddingRequest(request); err == nil {
- return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
- } else {
- return types.ActionContinue, err
+ qwenRequest, err := m.buildQwenTextEmbeddingRequest(request)
+ if err != nil {
+ return nil, err
}
+ return json.Marshal(qwenRequest)
}
func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -375,7 +307,7 @@ func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
-func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest, streaming bool) *qwenTextGenRequest {
+func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) {
messages := make([]qwenMessage, 0, len(origRequest.Messages))
for i := range origRequest.Messages {
messages = append(messages, chatMessage2QwenMessage(origRequest.Messages[i]))
@@ -397,6 +329,11 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio
Tools: origRequest.Tools,
},
}
+
+ if streaming {
+ ctx.SetContext(ctxKeyIncrementalStreaming, request.Parameters.IncrementalOutput)
+ }
+
if len(m.config.qwenFileIds) != 0 && origRequest.Model == qwenLongModelName {
builder := strings.Builder{}
for _, fileId := range m.config.qwenFileIds {
@@ -406,13 +343,15 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio
builder.WriteString("fileid://")
builder.WriteString(fileId)
}
- contextMessageId := m.insertContextMessage(request, builder.String(), true)
- if contextMessageId == 0 {
- // The context message cannot come first. We need to add another dummy system message before it.
- request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
+
+ body, err := json.Marshal(request)
+ if err != nil {
+ return nil, fmt.Errorf("unable to marshal request: %v", err)
}
+
+ return m.insertHttpContextMessage(body, builder.String(), true)
}
- return request
+ return json.Marshal(request)
}
func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse {
@@ -569,7 +508,12 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
return nil
}
-func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string, onlyOneSystemBeforeFile bool) int {
+func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
+ request := &qwenTextGenRequest{}
+ if err := json.Unmarshal(body, request); err != nil {
+ return nil, fmt.Errorf("unable to unmarshal request: %v", err)
+ }
+
fileMessage := qwenMessage{
Role: roleSystem,
Content: content,
@@ -586,10 +530,8 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
}
if firstNonSystemMessageIndex == 0 {
request.Input.Messages = append([]qwenMessage{fileMessage}, request.Input.Messages...)
- return 0
} else if !onlyOneSystemBeforeFile {
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
- return firstNonSystemMessageIndex
} else {
builder := strings.Builder{}
for _, message := range request.Input.Messages[:firstNonSystemMessageIndex] {
@@ -599,8 +541,15 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
builder.WriteString(message.StringContent())
}
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: builder.String()}, fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)
- return 1
+ firstNonSystemMessageIndex = 1
+ }
+
+ if firstNonSystemMessageIndex == 0 {
+ // The context message cannot come first. We need to add another dummy system message before it.
+ request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
}
+
+ return json.Marshal(request)
}
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
@@ -804,3 +753,16 @@ func chatMessage2QwenMessage(chatMessage chatMessage) qwenMessage {
}
}
}
+
+func (m *qwenProvider) GetApiName(path string) ApiName {
+ switch {
+ case strings.Contains(path, qwenChatCompletionPath),
+ strings.Contains(path, qwenMultimodalGenerationPath),
+ strings.Contains(path, qwenCompatiblePath):
+ return ApiNameChatCompletion
+ case strings.Contains(path, qwenTextEmbeddingPath):
+ return ApiNameEmbeddings
+ default:
+ return ""
+ }
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go
index 19060849ac..dd9864702e 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go
@@ -3,7 +3,6 @@ package provider
import (
"encoding/json"
"fmt"
-
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
@@ -18,6 +17,13 @@ func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) er
return nil
}
+func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error {
+ if err := json.Unmarshal(body, request); err != nil {
+ return fmt.Errorf("unable to unmarshal request: %v", err)
+ }
+ return nil
+}
+
func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
body, err := json.Marshal(request)
if err != nil {
@@ -31,6 +37,15 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
return err
}
+func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error {
+ log.Debugf("request body: %s", string(body))
+ err := proxywasm.ReplaceHttpRequestBody(body)
+ if err != nil {
+ return fmt.Errorf("unable to replace the original request body: %v", err)
+ }
+ return nil
+}
+
func insertContextMessage(request *chatCompletionRequest, content string) {
fileMessage := chatMessage{
Role: roleSystem,
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go
index fc266dfbaa..c2e013643c 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go
@@ -2,8 +2,8 @@ package provider
import (
"encoding/json"
- "errors"
"fmt"
+ "net/http"
"strings"
"time"
@@ -71,11 +71,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestHost(sparkHost)
- _ = util.OverwriteRequestPath(sparkChatCompletionPath)
- _ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ p.config.handleRequestHeaders(p, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -83,36 +79,7 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- // 使用Spark协议
- if p.config.protocol == protocolOriginal {
- request := &sparkRequest{}
- if err := json.Unmarshal(body, request); err != nil {
- return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
- }
- if request.Model == "" {
- return types.ActionContinue, errors.New("request model is empty")
- }
- // 目前星火在模型名称错误时,也会调用generalv3,这里还是按照输入的模型名称设置响应里的模型名称
- ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
- return types.ActionContinue, replaceJsonRequestBody(request, log)
- } else {
- // 使用openai协议
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- if request.Model == "" {
- return types.ActionContinue, errors.New("missing model in chat completion request")
- }
- // 映射模型
- mappedModel := getMappedModel(request.Model, p.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
- request.Model = mappedModel
- return types.ActionContinue, replaceJsonRequestBody(request, log)
- }
+ return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
}
func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -205,3 +172,11 @@ func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, resp
func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
+
+func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, sparkHost)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
+ headers.Del("Accept-Encoding")
+ headers.Del("Content-Length")
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go
index dd6792ed65..1ee01abe62 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go
@@ -2,12 +2,10 @@ package provider
import (
"errors"
- "fmt"
-
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+ "net/http"
)
const (
@@ -45,10 +43,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(stepfunChatCompletionPath)
- _ = util.OverwriteRequestHost(stepfunDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -56,28 +51,12 @@ func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- if m.contextCache == nil {
- return types.ActionContinue, nil
- }
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.stepfun.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.stepfun.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
+
+func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, stepfunChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, stepfunDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go
index 287945d903..7cb05a9388 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go
@@ -2,11 +2,10 @@ package provider
import (
"errors"
- "fmt"
+ "net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -45,10 +44,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(yiChatCompletionPath)
- _ = util.OverwriteRequestHost(yiDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -56,28 +52,12 @@ func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, bod
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- if m.contextCache == nil {
- return types.ActionContinue, nil
- }
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.yi.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.yi.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
- }
- return types.ActionContinue, err
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
+
+func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, yiChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, yiDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go
index 9640cd02f4..40fbe4ef88 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go
@@ -2,11 +2,11 @@ package provider
import (
"errors"
- "fmt"
+ "net/http"
+ "strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -44,10 +44,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- _ = util.OverwriteRequestPath(zhipuAiChatCompletionPath)
- _ = util.OverwriteRequestHost(zhipuAiDomain)
- _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
- _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+ m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -55,28 +52,19 @@ func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- if m.contextCache == nil {
- return types.ActionContinue, nil
- }
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return types.ActionContinue, err
- }
- err := m.contextCache.GetContent(func(content string, err error) {
- defer func() {
- _ = proxywasm.ResumeHttpRequest()
- }()
- if err != nil {
- log.Errorf("failed to load context file: %v", err)
- _ = util.SendResponse(500, "ai-proxy.zhihupai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
- }
- insertContextMessage(request, content)
- if err := replaceJsonRequestBody(request, log); err != nil {
- _ = util.SendResponse(500, "ai-proxy.zhihupai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
- }
- }, log)
- if err == nil {
- return types.ActionPause, nil
+ return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
+}
+
+func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
+ util.OverwriteRequestPathHeader(headers, zhipuAiChatCompletionPath)
+ util.OverwriteRequestHostHeader(headers, zhipuAiDomain)
+ util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
+ headers.Del("Content-Length")
+}
+
+func (m *zhipuAiProvider) GetApiName(path string) ApiName {
+ if strings.Contains(path, zhipuAiChatCompletionPath) {
+ return ApiNameChatCompletion
}
- return types.ActionContinue, err
+ return ""
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go
index 43135ec0a2..f0d4c0ce7c 100644
--- a/plugins/wasm-go/extensions/ai-proxy/util/http.go
+++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go
@@ -1,6 +1,10 @@
package util
-import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
+import (
+ "net/http"
+
+ "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
+)
const (
HeaderContentType = "Content-Type"
@@ -21,13 +25,6 @@ func CreateHeaders(kvs ...string) [][2]string {
return headers
}
-func OverwriteRequestHost(host string) error {
- if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil {
- _ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-HOST", originHost)
- }
- return proxywasm.ReplaceHttpRequestHeader(":authority", host)
-}
-
func OverwriteRequestPath(path string) error {
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
_ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-PATH", originPath)
@@ -43,3 +40,56 @@ func OverwriteRequestAuthorization(credential string) error {
}
return proxywasm.ReplaceHttpRequestHeader("Authorization", credential)
}
+
+func OverwriteRequestHostHeader(headers http.Header, host string) {
+ if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil {
+ headers.Set("X-ENVOY-ORIGINAL-HOST", originHost)
+ }
+ headers.Set(":authority", host)
+}
+
+func OverwriteRequestPathHeader(headers http.Header, path string) {
+ if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
+ headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
+ }
+ headers.Set(":path", path)
+}
+
+func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {
+ if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" {
+ if originAuth := headers.Get("Authorization"); originAuth != "" {
+ headers.Set("X-HI-ORIGINAL-AUTH", originAuth)
+ }
+ }
+ headers.Set("Authorization", credential)
+}
+
+func HeaderToSlice(header http.Header) [][2]string {
+ slice := make([][2]string, 0, len(header))
+ for key, values := range header {
+ for _, value := range values {
+ slice = append(slice, [2]string{key, value})
+ }
+ }
+ return slice
+}
+
+func SliceToHeader(slice [][2]string) http.Header {
+ header := make(http.Header)
+ for _, pair := range slice {
+ key := pair[0]
+ value := pair[1]
+ header.Add(key, value)
+ }
+ return header
+}
+
+func GetOriginalHttpHeaders() http.Header {
+ originalHeaders, _ := proxywasm.GetHttpRequestHeaders()
+ return SliceToHeader(originalHeaders)
+}
+
+func ReplaceOriginalHttpHeaders(headers http.Header) {
+ modifiedHeaders := HeaderToSlice(headers)
+ _ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders)
+}
diff --git a/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go b/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go
index 96600192b1..e797394b54 100644
--- a/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go
+++ b/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go
@@ -45,6 +45,19 @@ func (c RouteCluster) HostName() string {
return GetRequestHost()
}
+type TargetCluster struct {
+ Host string
+ Cluster string
+}
+
+func (c TargetCluster) ClusterName() string {
+ return c.Cluster
+}
+
+func (c TargetCluster) HostName() string {
+ return c.Host
+}
+
type K8sCluster struct {
ServiceName string
Namespace string