Skip to content

Commit

Permalink
optimize ai proxy (#1603)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnlanni authored Dec 19, 2024
1 parent d74d327 commit 39c007d
Show file tree
Hide file tree
Showing 11 changed files with 20 additions and 35 deletions.
20 changes: 13 additions & 7 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,28 +89,34 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}

if apiName == "" {
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
log.Debugf("[onHttpRequestHeader] no send response")
log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
return types.ActionContinue
}
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()

ctx.SetContext(ctxKeyApiName, apiName)

_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
if needHandleBody || needHandleStreamingBody {
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
}

if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)

hasRequestBody := wrapper.HasRequestBody()
err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Always return types.HeaderStopIteration to support fallback routing,
// as long as onHttpRequestBody can be called.
// Delay the header processing so the headers can still be changed in OnRequestBody
return types.HeaderStopIteration
}
ctx.DontReadRequestBody()
return types.ActionContinue
}

Expand Down
4 changes: 1 addition & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,5 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,

func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, ai360Domain)
util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
}
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
}
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx))
headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
6 changes: 2 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,13 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
util.OverwriteRequestHostHeader(headers, claudeDomain)

headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx))
headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))

if c.config.claudeVersion == "" {
c.config.claudeVersion = defaultVersion
}

headers.Add("anthropic-version", c.config.claudeVersion)
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
headers.Set("anthropic-version", c.config.claudeVersion)
}

func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
2 changes: 0 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,4 @@ func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, ap
util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
util.OverwriteRequestHostHeader(headers, cloudflareDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
2 changes: 0 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
headers.Del("Accept-Encoding")
}

func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
4 changes: 1 addition & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa

func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, geminiDomain)
headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
}

func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down
2 changes: 0 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
}
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}

func (m *githubProvider) GetApiName(path string) ApiName {
Expand Down
7 changes: 2 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,8 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)

// Add the custom header fields required by hunyuan
headers.Add(actionKey, hunyuanChatCompletionTCAction)
headers.Add(versionKey, versionValue)

headers.Del("Accept-Encoding")
headers.Del("Content-Length")
headers.Set(actionKey, hunyuanChatCompletionTCAction)
headers.Set(versionKey, versionValue)
}

// hunyuan's OnRequestBody logic includes signing the headers, and the signature must be recomputed after the context is inserted, so the handleRequestBody method cannot be reused
Expand Down
4 changes: 0 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,6 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
} else if apiName == ApiNameEmbeddings {
util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
}

headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}

func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
Expand All @@ -109,7 +106,6 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
return nil
}

// Delay the header processing to allow changing streaming mode in OnRequestBody
return nil
}

Expand Down
2 changes: 0 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,4 @@ func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
util.OverwriteRequestHostHeader(headers, sparkHost)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}

0 comments on commit 39c007d

Please sign in to comment.