Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: retry failed request #1590

Merged
merged 13 commits into from
Dec 26, 2024
25 changes: 17 additions & 8 deletions plugins/wasm-go/extensions/ai-proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ description: AI 代理插件配置参考
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |

`context`的配置字段说明如下:

Expand Down Expand Up @@ -78,14 +79,22 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字

`failover` 的配置字段说明如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|------|-------|-----------------------------|
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
| healthCheckModel | string | 必填 | | 健康检测使用的模型 |
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|-----------------|-------|-----------------------------|
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
| healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 |

`retryOnFailure` 的配置字段说明如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|-----------------|-------|-------------|
| enabled | bool | 非必填 | false | 是否启用失败请求重试 |
| maxRetries | int | 非必填 | 1 | 最大重试次数 |
| retryTimeout | int | 非必填 | 5000 | 重试超时时间,单位毫秒 |

### 提供商特有配置

Expand Down
55 changes: 23 additions & 32 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
const (
pluginName = "ai-proxy"

ctxKeyApiName = "apiName"

defaultMaxBodyBytes uint32 = 10 * 1024 * 1024
)

Expand Down Expand Up @@ -92,14 +90,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
return types.ActionContinue
}

ctx.SetContext(provider.CtxKeyApiName, apiName)
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved

ctx.SetContext(ctxKeyApiName, apiName)

_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
if needHandleBody || needHandleStreamingBody {
if needHandleStreamingBody {
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
}

Expand Down Expand Up @@ -138,7 +135,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())

if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)

newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
if settingErr != nil {
Expand Down Expand Up @@ -186,32 +183,25 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
log.Errorf("unable to load :status header from response: %v", err)
}
ctx.DontReadResponseBody()
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)

return types.ActionContinue
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, log)
}

// Reset ctxApiTokenRequestFailureCount if the request is successful,
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)

if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
if err == nil {
checkStream(&ctx, log)
return action
}
util.ErrorHandler("ai-proxy.proc_resp_headers_failed", fmt.Errorf("failed to process response headers: %v", err))
return types.ActionContinue
headers := util.GetOriginalResponseHeaders()
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
handler.TransformResponseHeaders(ctx, apiName, headers, log)
} else {
providerConfig.DefaultTransformResponseHeaders(ctx, headers)
}
util.ReplaceResponseHeaders(headers)

checkStream(&ctx, log)
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
if !needHandleBody && !needHandleStreamingBody {
ctx.DontReadResponseBody()
} else if !needHandleStreamingBody {
if !needHandleStreamingBody {
ctx.BufferResponseBody()
}

Expand All @@ -230,7 +220,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))

if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
if err == nil && modifiedChunk != nil {
return modifiedChunk
Expand All @@ -249,16 +239,17 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
}

log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
//log.Debugf("response body: %s", string(body))

if handler, ok := activeProvider.(provider.ResponseBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseBody(ctx, apiName, body, log)
if err == nil {
return action
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
if err != nil {
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
return types.ActionContinue
}
if err = provider.ReplaceResponseBody(body, log); err != nil {
util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
}
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
return types.ActionContinue
}
return types.ActionContinue
}
Expand Down
20 changes: 4 additions & 16 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)

Expand Down Expand Up @@ -139,27 +138,16 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
return json.Marshal(claudeRequest)
}

func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
claudeResponse := &claudeTextGenResponse{}
if err := json.Unmarshal(body, claudeResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal claude response: %v", err)
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
}
if claudeResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
return nil, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
}
response := c.responseClaude2OpenAI(ctx, claudeResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
}

func (c *claudeProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
// use original protocol, skip OnStreamingResponseBody() and OnResponseBody()
if c.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}

_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
return json.Marshal(response)
}

func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func insertContext(provider Provider, content string, err error, body []byte, lo
if err != nil {
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
}
if err := replaceHttpJsonRequestBody(body, log); err != nil {
if err := replaceRequestBody(body, log); err != nil {
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
}
}
Expand Down
12 changes: 3 additions & 9 deletions plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)

Expand Down Expand Up @@ -112,18 +111,13 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
return json.Marshal(baiduRequest)
}

func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}

func (d *deeplProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
deeplResponse := &deeplResponse{}
if err := json.Unmarshal(body, deeplResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal deepl response: %v", err)
return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err)
}
response := d.responseDeepl2OpenAI(ctx, deeplResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}

func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse {
Expand Down
17 changes: 11 additions & 6 deletions plugins/wasm-go/extensions/ai-proxy/provider/failover.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (

type failover struct {
// @Title zh-CN 是否启用 apiToken 的 failover 机制
enabled bool `required:"true" yaml:"enabled" json:"enabled"`
enabled bool `required:"false" yaml:"enabled" json:"enabled"`
// @Title zh-CN 触发 failover 连续请求失败的阈值
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
// @Title zh-CN 健康检测的成功阈值
Expand All @@ -29,7 +29,7 @@ type failover struct {
// @Title zh-CN 健康检测的超时时间,单位毫秒
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
// @Title zh-CN 健康检测使用的模型
healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
// @Title zh-CN 本次请求使用的 apiToken
ctxApiTokenInUse string
// @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数
Expand Down Expand Up @@ -184,9 +184,9 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext,
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalHttpHeaders()
headers := util.GetOriginalRequestHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
util.ReplaceOriginalHttpHeaders(headers)
util.ReplaceRequestHeaders(headers)
} else {
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
}
Expand Down Expand Up @@ -539,10 +539,15 @@ func (c *ProviderConfig) resetSharedData() {
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
}

func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) {
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) types.Action {
if c.isFailoverEnabled() {
c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
}
if c.isRetryOnFailureEnabled() && !ctx.GetContext(ctxKeyIsStreaming).(bool) {
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved
c.retryFailedRequest(activeProvider, ctx, log)
return types.HeaderStopAllIterationAndWatermark
}
return types.ActionContinue
}

func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
Expand All @@ -557,7 +562,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L
} else {
apiToken = c.GetRandomToken()
}
log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken)
log.Debugf("Use apiToken %s to send request", apiToken)
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
}

Expand Down
31 changes: 10 additions & 21 deletions plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,6 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
return json.Marshal(geminiRequest)
}

func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if g.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}

_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}

func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
log.Infof("chunk body:%s", string(chunk))
if isLastChunk || len(chunk) == 0 {
Expand Down Expand Up @@ -148,39 +138,38 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
return []byte(modifiedResponseChunk), nil
}

func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName == ApiNameChatCompletion {
return g.onChatCompletionResponseBody(ctx, body, log)
} else if apiName == ApiNameEmbeddings {
} else {
return g.onEmbeddingsResponseBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
}

func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
geminiResponse := &geminiChatResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
}
if geminiResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
return nil, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
}
response := g.buildChatCompletionResponse(ctx, geminiResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}

func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
geminiResponse := &geminiEmbeddingResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
}
if geminiResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
return nil, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
}
response := g.buildEmbeddingsResponse(ctx, geminiResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}

func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
Expand Down
Loading
Loading