From 1968022907d8e1d39f9a9118707d6eb300f91975 Mon Sep 17 00:00:00 2001
From: Zhiqiang Li
Date: Tue, 28 Mar 2023 14:34:21 +0800
Subject: [PATCH] feat: support multi model

---
 Dockerfile                            |  1 +
 apis/chat.go                          | 43 --------------
 apis/vars.go                          |  9 ---
 azure/init.go                         | 58 +++++++++++++++++++
 azure/proxy.go                        | 81 +++++++++++++++++++++++++++
 build.sh                              |  2 +-
 cmd/main.go                           |  3 -
 cmd/router.go                         |  6 +-
 constant/env.go                       |  6 +-
 go.mod                                |  3 +
 go.sum                                |  6 ++
 openai/chat.go                        | 11 ----
 openai/init.go                        | 20 -------
 openai/vars.go                        | 17 ------
 apis/tools.go => util/response_err.go |  6 +-
 {apis => util}/types.go               |  2 +-
 16 files changed, 162 insertions(+), 112 deletions(-)
 delete mode 100644 apis/chat.go
 delete mode 100644 apis/vars.go
 create mode 100644 azure/init.go
 create mode 100644 azure/proxy.go
 delete mode 100644 openai/chat.go
 delete mode 100644 openai/init.go
 delete mode 100644 openai/vars.go
 rename apis/tools.go => util/response_err.go (74%)
 rename {apis => util}/types.go (92%)

diff --git a/Dockerfile b/Dockerfile
index ebea5ea..d835b0c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,6 @@
 FROM alpine:3
 
+EXPOSE 8080
 COPY ./bin/azure-openai-proxy /usr/bin
 
 ENTRYPOINT ["/usr/bin/azure-openai-proxy"]
\ No newline at end of file
diff --git a/apis/chat.go b/apis/chat.go
deleted file mode 100644
index 02b0cea..0000000
--- a/apis/chat.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package apis
-
-import (
-	"github.com/gin-gonic/gin"
-	"github.com/pkg/errors"
-	"github.com/stulzq/azure-openai-proxy/openai"
-	"io"
-	"strings"
-)
-
-// ChatCompletions xxx
-// Path: /v1/chat/completions
-func ChatCompletions(c *gin.Context) {
-	// get auth token from header
-	rawToken := c.GetHeader("Authorization")
-	token := strings.TrimPrefix(rawToken, "Bearer ")
-
-	reqContent, err := io.ReadAll(c.Request.Body)
-	if err != nil {
-		SendError(c, errors.Wrap(err, "failed to read request body"))
-		return
-	}
-
-	oaiResp, err := openai.ChatCompletions(token, reqContent)
-	if err != nil {
-		SendError(c, errors.Wrap(err, "failed to call Azure OpenAI"))
-		return
-	}
-
-	// pass-through header
-	extraHeaders := map[string]string{}
-	for k, v := range oaiResp.Header {
-		if _, ok := ignoreHeaders[k]; ok {
-			continue
-		}
-
-		extraHeaders[k] = strings.Join(v, ",")
-	}
-
-	c.DataFromReader(oaiResp.StatusCode, oaiResp.ContentLength, oaiResp.Header.Get("Content-Type"), oaiResp.Response.Body, extraHeaders)
-
-	_, _ = c.Writer.Write([]byte{'\n'}) // add a newline to the end of the response https://github.com/Chanzhaoyu/chatgpt-web/issues/831
-}
diff --git a/apis/vars.go b/apis/vars.go
deleted file mode 100644
index 086faf6..0000000
--- a/apis/vars.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package apis
-
-var (
-	ignoreHeaders = map[string]int{
-		"Content-Type":      1,
-		"Transfer-Encoding": 1,
-		"Date":              1,
-	}
-)
diff --git a/azure/init.go b/azure/init.go
new file mode 100644
index 0000000..3e4a88b
--- /dev/null
+++ b/azure/init.go
@@ -0,0 +1,58 @@
+package azure
+
+import (
+	"log"
+	"net/url"
+	"os"
+	"regexp"
+	"strings"
+
+	"github.com/stulzq/azure-openai-proxy/constant"
+)
+
+const (
+	AuthHeaderKey = "api-key"
+)
+
+var (
+	AzureOpenAIEndpoint      = ""
+	AzureOpenAIEndpointParse *url.URL
+
+	AzureOpenAIAPIVer = ""
+
+	AzureOpenAIModelMapper = map[string]string{
+		"gpt-3.5-turbo":      "gpt-35-turbo",
+		"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
+	}
+	fallbackModelMapper = regexp.MustCompile(`[.:]`)
+)
+
+func init() {
+	AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
+	AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
+
+	if AzureOpenAIAPIVer == "" {
+		AzureOpenAIAPIVer = "2023-03-15-preview"
+	}
+
+	var err error
+	AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
+	if err != nil {
+		log.Fatal("parse AzureOpenAIEndpoint error: ", err)
+	}
+
+	if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
+		for _, pair := range strings.Split(v, ",") {
+			info := strings.Split(pair, "=")
+			if len(info) != 2 {
+				log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
+			}
+
+			AzureOpenAIModelMapper[info[0]] = info[1]
+		}
+	}
+
+	log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
+	log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
+	log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
+}
diff --git a/azure/proxy.go b/azure/proxy.go
new file mode 100644
index 0000000..d0b74c7
--- /dev/null
+++ b/azure/proxy.go
@@ -0,0 +1,81 @@
+package azure
+
+import (
+	"bytes"
+	"fmt"
+	"github.com/bytedance/sonic"
+	"github.com/pkg/errors"
+	"github.com/stulzq/azure-openai-proxy/util"
+	"io"
+	"log"
+	"net/http"
+	"net/http/httputil"
+	"path"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+// Proxy Azure OpenAI
+func Proxy(c *gin.Context) {
+	// improve performance some code from https://github.com/diemus/azure-openai-proxy/blob/main/pkg/azure/proxy.go
+	director := func(req *http.Request) {
+		if req.Body == nil {
+			util.SendError(c, errors.New("request body is empty"))
+			return
+		}
+		body, _ := io.ReadAll(req.Body)
+		req.Body = io.NopCloser(bytes.NewBuffer(body))
+
+		// get model from body
+		model, err := sonic.Get(body, "model")
+		if err != nil {
+			util.SendError(c, errors.Wrap(err, "get model error"))
+			return
+		}
+
+		deployment, err := model.String()
+		if err != nil {
+			util.SendError(c, errors.Wrap(err, "get deployment error"))
+			return
+		}
+		deployment = GetDeploymentByModel(deployment)
+
+		// get auth token from header
+		rawToken := req.Header.Get("Authorization")
+		token := strings.TrimPrefix(rawToken, "Bearer ")
+		req.Header.Set(AuthHeaderKey, token)
+		req.Header.Del("Authorization")
+
+		originURL := req.URL.String()
+		req.Host = AzureOpenAIEndpointParse.Host
+		req.URL.Scheme = AzureOpenAIEndpointParse.Scheme
+		req.URL.Host = AzureOpenAIEndpointParse.Host
+		req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1))
+		req.URL.RawPath = req.URL.EscapedPath()
+
+		query := req.URL.Query()
+		query.Add("api-version", AzureOpenAIAPIVer)
+		req.URL.RawQuery = query.Encode()
+
+		log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
+	}
+
+	proxy := &httputil.ReverseProxy{Director: director}
+	proxy.ServeHTTP(c.Writer, c.Request)
+
+	// https://github.com/Chanzhaoyu/chatgpt-web/issues/831
+	if c.Writer.Header().Get("Content-Type") == "text/event-stream" {
+		if _, err := c.Writer.Write([]byte{'\n'}); err != nil {
+			log.Printf("rewrite response error: %v", err)
+		}
+	}
+}
+
+func GetDeploymentByModel(model string) string {
+	if v, ok := AzureOpenAIModelMapper[model]; ok {
+		return v
+	}
+
+	return fallbackModelMapper.ReplaceAllString(model, "")
+}
diff --git a/build.sh b/build.sh
index 3e470cd..922d2e0 100644
--- a/build.sh
+++ b/build.sh
@@ -2,7 +2,7 @@
 
 set -e
 
-VERSION=v1.0.0
+VERSION=v1.1.0
 
 rm -rf bin
 
diff --git a/cmd/main.go b/cmd/main.go
index 39a477c..249c186 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -4,7 +4,6 @@ import (
 	"context"
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
-	"github.com/stulzq/azure-openai-proxy/openai"
 	"log"
 	"net/http"
 	"os"
@@ -13,8 +12,6 @@ import (
 )
 
 func main() {
-	openai.Init()
-
 	gin.SetMode(gin.ReleaseMode)
 	r := gin.Default()
 	registerRoute(r)
diff --git a/cmd/router.go b/cmd/router.go
index 7d9caa7..0ca1b71 100644
--- a/cmd/router.go
+++ b/cmd/router.go
@@ -2,9 +2,11 @@ package main
 
 import (
 	"github.com/gin-gonic/gin"
-	"github.com/stulzq/azure-openai-proxy/apis"
+	"github.com/stulzq/azure-openai-proxy/azure"
 )
 
+// registerRoute registers all routes
 func registerRoute(r *gin.Engine) {
-	r.POST("/v1/chat/completions", apis.ChatCompletions)
+	// https://platform.openai.com/docs/api-reference
+	r.Any("*path", azure.Proxy)
 }
diff --git a/constant/env.go b/constant/env.go
index ce714c4..a0702d3 100644
--- a/constant/env.go
+++ b/constant/env.go
@@ -1,7 +1,7 @@
 package constant
 
 const (
-	ENV_AZURE_OPENAI_ENDPOINT = "AZURE_OPENAI_ENDPOINT"
-	ENV_AZURE_OPENAI_API_VER  = "AZURE_OPENAI_API_VER"
-	ENV_AZURE_OPENAI_DEPLOY   = "AZURE_OPENAI_DEPLOY"
+	ENV_AZURE_OPENAI_ENDPOINT     = "AZURE_OPENAI_ENDPOINT"
+	ENV_AZURE_OPENAI_API_VER      = "AZURE_OPENAI_API_VER"
+	ENV_AZURE_OPENAI_MODEL_MAPPER = "AZURE_OPENAI_MODEL_MAPPER"
 )
diff --git a/go.mod b/go.mod
index 10463b1..6f4cc20 100644
--- a/go.mod
+++ b/go.mod
@@ -35,6 +35,9 @@ require (
 	github.com/quic-go/qtls-go1-19 v0.2.0 // indirect
 	github.com/quic-go/qtls-go1-20 v0.1.0 // indirect
 	github.com/quic-go/quic-go v0.32.0 // indirect
+	github.com/tidwall/gjson v1.14.4 // indirect
+	github.com/tidwall/match v1.1.1 // indirect
+	github.com/tidwall/pretty v1.2.0 // indirect
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.11 // indirect
 	golang.org/x/arch v0.3.0 // indirect
diff --git a/go.sum b/go.sum
index 16e9147..389dbc2 100644
--- a/go.sum
+++ b/go.sum
@@ -95,6 +95,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
 github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
 github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
+github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
 github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
 github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
 github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
diff --git a/openai/chat.go b/openai/chat.go
deleted file mode 100644
index 63368a4..0000000
--- a/openai/chat.go
+++ /dev/null
@@ -1,11 +0,0 @@
-package openai
-
-import "github.com/imroc/req/v3"
-
-func ChatCompletions(token string, body []byte) (*req.Response, error) {
-	return client.R().
-		SetHeader("Content-Type", "application/json").
-		SetHeader(AuthHeaderKey, token).
-		SetBodyBytes(body).
-		Post(ChatCompletionsUrl)
-}
diff --git a/openai/init.go b/openai/init.go
deleted file mode 100644
index d680c25..0000000
--- a/openai/init.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package openai
-
-import (
-	"fmt"
-	"github.com/stulzq/azure-openai-proxy/constant"
-	"log"
-	"os"
-)
-
-func Init() {
-	AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
-	AzureOpenAIDeploy = os.Getenv(constant.ENV_AZURE_OPENAI_DEPLOY)
-	AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
-
-	log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
-	log.Println("AzureOpenAIDeploy: ", AzureOpenAIDeploy)
-	log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
-
-	ChatCompletionsUrl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", AzureOpenAIEndpoint, AzureOpenAIDeploy, AzureOpenAIAPIVer)
-}
diff --git a/openai/vars.go b/openai/vars.go
deleted file mode 100644
index 619585f..0000000
--- a/openai/vars.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package openai
-
-import "github.com/imroc/req/v3"
-
-const (
-	AuthHeaderKey = "api-key"
-)
-
-var (
-	AzureOpenAIEndpoint = ""
-	AzureOpenAIAPIVer   = ""
-	AzureOpenAIDeploy   = ""
-
-	ChatCompletionsUrl = ""
-
-	client = req.C()
-)
diff --git a/apis/tools.go b/util/response_err.go
similarity index 74%
rename from apis/tools.go
rename to util/response_err.go
index 95fa02a..9ed26d4 100644
--- a/apis/tools.go
+++ b/util/response_err.go
@@ -1,6 +1,8 @@
-package apis
+package util
 
-import "github.com/gin-gonic/gin"
+import (
+	"github.com/gin-gonic/gin"
+)
 
 func SendError(c *gin.Context, err error) {
 	c.JSON(500, ApiResponse{
diff --git a/apis/types.go b/util/types.go
similarity index 92%
rename from apis/types.go
rename to util/types.go
index 527476d..eb28cbb 100644
--- a/apis/types.go
+++ b/util/types.go
@@ -1,4 +1,4 @@
-package apis
+package util
 
 type ApiResponse struct {
 	Error ErrorDescription `json:"error"`
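
A few notes on how the multi-model support in this patch behaves, with small illustrative sketches (none of the code below is part of the diff above).

Two things replace the old single AZURE_OPENAI_DEPLOY setting: an explicit model-to-deployment table (AzureOpenAIModelMapper, pre-seeded for gpt-3.5-turbo and gpt-3.5-turbo-0301, and extendable through AZURE_OPENAI_MODEL_MAPPER as comma-separated model=deployment pairs), and a fallback that strips '.' and ':' from any unmapped model name. The standalone sketch below mirrors GetDeploymentByModel so the behaviour can be tried without running the proxy; the gpt-4=gpt-4-canada entry is a made-up example mapping, not something the patch ships.

    package main

    import (
        "fmt"
        "regexp"
    )

    // Mirror of azure.AzureOpenAIModelMapper after init() has parsed a
    // hypothetical AZURE_OPENAI_MODEL_MAPPER="gpt-4=gpt-4-canada".
    var modelMapper = map[string]string{
        "gpt-3.5-turbo":      "gpt-35-turbo",
        "gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
        "gpt-4":              "gpt-4-canada",
    }

    // Same fallback as azure/init.go: drop '.' and ':' from unmapped names.
    var fallbackModelMapper = regexp.MustCompile(`[.:]`)

    // getDeploymentByModel mirrors azure.GetDeploymentByModel in the patch.
    func getDeploymentByModel(model string) string {
        if v, ok := modelMapper[model]; ok {
            return v
        }
        return fallbackModelMapper.ReplaceAllString(model, "")
    }

    func main() {
        fmt.Println(getDeploymentByModel("gpt-3.5-turbo"))      // gpt-35-turbo (explicit entry)
        fmt.Println(getDeploymentByModel("gpt-4"))              // gpt-4-canada (from the env var)
        fmt.Println(getDeploymentByModel("gpt-3.5-turbo-0613")) // gpt-35-turbo-0613 (fallback strips the dot)
    }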
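
The director rewrites the incoming OpenAI-style URL with a single strings.Replace of the /v1/ prefix plus path.Join under /openai/deployments/<deployment>, then appends api-version (defaulting to 2023-03-15-preview). A minimal sketch of just that rewrite, using a hypothetical endpoint https://example.openai.azure.com, prints the URL the proxy would forward to:

    package main

    import (
        "fmt"
        "net/url"
        "path"
        "strings"
    )

    func main() {
        endpoint, _ := url.Parse("https://example.openai.azure.com") // hypothetical AZURE_OPENAI_ENDPOINT
        apiVersion := "2023-03-15-preview"                           // default AzureOpenAIAPIVer from the patch
        deployment := "gpt-35-turbo"                                 // GetDeploymentByModel("gpt-3.5-turbo")

        // Same rewrite the director applies to an incoming OpenAI-style request.
        u, _ := url.Parse("/v1/chat/completions")
        u.Scheme = endpoint.Scheme
        u.Host = endpoint.Host
        u.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment),
            strings.Replace(u.Path, "/v1/", "/", 1))

        q := u.Query()
        q.Add("api-version", apiVersion)
        u.RawQuery = q.Encode()

        fmt.Println(u.String())
        // https://example.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview
    }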
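
Since the router now catches every path with r.Any("*path", azure.Proxy), an unmodified OpenAI client only needs to point at the proxy. A rough usage sketch, assuming the proxy is reachable at http://localhost:8080 (suggested by the new EXPOSE 8080 line; the actual listen address lives in cmd/main.go) and that the bearer value is an Azure OpenAI key, which the director forwards as the api-key header:

    package main

    import (
        "bytes"
        "fmt"
        "io"
        "net/http"
    )

    func main() {
        // Standard OpenAI chat payload; "model" picks the Azure deployment via the mapper.
        payload := []byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hello"}]}`)

        req, err := http.NewRequest(http.MethodPost,
            "http://localhost:8080/v1/chat/completions", bytes.NewReader(payload))
        if err != nil {
            panic(err)
        }
        req.Header.Set("Content-Type", "application/json")
        // The proxy strips "Bearer " and forwards the remainder as the api-key header.
        req.Header.Set("Authorization", "Bearer "+"<azure-openai-api-key>")

        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        body, _ := io.ReadAll(resp.Body)
        fmt.Println(resp.Status)
        fmt.Println(string(body))
    }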