Skip to content

Commit

Permalink
ai proxy return unified status in header phase
Browse files Browse the repository at this point in the history
  • Loading branch information
StarryVae committed Dec 16, 2024
1 parent ec39d56 commit 2a22b03
Show file tree
Hide file tree
Showing 27 changed files with 94 additions and 88 deletions.
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
providerConfig.SetApiTokenInUse(ctx, log)

hasRequestBody := wrapper.HasRequestBody()
action, err := handler.OnRequestHeaders(ctx, apiName, log)
err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
if hasRequestBody {
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Always return types.HeaderStopIteration to support fallback routing,
// as long as onHttpRequestBody can be called.
return types.HeaderStopIteration
}
return action
return types.ActionContinue
}

util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ func (m *ai360Provider) GetProviderType() string {
return providerTypeAi360
}

func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
return nil
}

func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ func (m *azureProvider) GetProviderType() string {
return providerTypeAzure
}

func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ func (m *baichuanProvider) GetProviderType() string {
return providerTypeBaichuan
}

func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ func (g *baiduProvider) GetProviderType() string {
return providerTypeBaidu
}

func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ func (c *claudeProvider) GetProviderType() string {
return providerTypeClaude
}

func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ func (c *cloudflareProvider) GetProviderType() string {
return providerTypeCloudflare
}

func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
11 changes: 6 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package provider
import (
"encoding/json"
"errors"
"net/http"
"strings"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
)

const (
Expand Down Expand Up @@ -54,12 +55,12 @@ func (m *cohereProvider) GetProviderType() string {
return providerTypeCohere
}

func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
5 changes: 2 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/coze.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)

const (
Expand Down Expand Up @@ -38,9 +37,9 @@ func (m *cozeProvider) GetProviderType() string {
return providerTypeCoze
}

func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ func (d *deeplProvider) GetProviderType() string {
return providerTypeDeepl
}

func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
d.config.handleRequestHeaders(d, ctx, apiName, log)
return types.HeaderStopIteration, nil
return nil
}

func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
Expand Down
9 changes: 5 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package provider

import (
"errors"
"net/http"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
)

// deepseekProvider is the provider for deepseek Ai service.
Expand Down Expand Up @@ -41,12 +42,12 @@ func (m *deepseekProvider) GetProviderType() string {
return providerTypeDeepSeek
}

func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
11 changes: 6 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package provider

import (
"errors"
"net/http"
"strings"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
)

const (
Expand Down Expand Up @@ -39,12 +40,12 @@ func (m *doubaoProvider) GetProviderType() string {
return providerTypeDoubao
}

func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ func (g *geminiProvider) GetProviderType() string {
return providerTypeGemini
}

func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
return nil
}

func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
Expand Down
11 changes: 6 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package provider

import (
"errors"
"net/http"
"strings"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
)

// githubProvider is the provider for GitHub OpenAI service.
Expand Down Expand Up @@ -42,13 +43,13 @@ func (m *githubProvider) GetProviderType() string {
return providerTypeGithub
}

func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
return nil
}

func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/groq.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ func (g *groqProvider) GetProviderType() string {
return providerTypeGroq
}

func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ func (m *hunyuanProvider) GetProviderType() string {
return providerTypeHunyuan
}

func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
return nil
}

func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ func (m *minimaxProvider) GetProviderType() string {
return providerTypeMinimax
}

func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
return nil
}

func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
Expand Down
9 changes: 5 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/mistral.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package provider

import (
"errors"
"net/http"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
)

const (
Expand Down Expand Up @@ -37,12 +38,12 @@ func (m *mistralProvider) GetProviderType() string {
return providerTypeMistral
}

func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ func (m *moonshotProvider) GetProviderType() string {
return providerTypeMoonshot
}

func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
return nil
}

func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
Expand Down
Loading

0 comments on commit 2a22b03

Please sign in to comment.