Skip to content

Commit

Permalink
feat: support multi model
Browse files Browse the repository at this point in the history
  • Loading branch information
stulzq committed Mar 28, 2023
1 parent 7b9dff6 commit 1968022
Show file tree
Hide file tree
Showing 16 changed files with 162 additions and 112 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
FROM alpine:3

EXPOSE 8080
COPY ./bin/azure-openai-proxy /usr/bin

ENTRYPOINT ["/usr/bin/azure-openai-proxy"]
43 changes: 0 additions & 43 deletions apis/chat.go

This file was deleted.

9 changes: 0 additions & 9 deletions apis/vars.go

This file was deleted.

58 changes: 58 additions & 0 deletions azure/init.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package azure

import (
"log"
"net/url"
"os"
"regexp"
"strings"

"github.com/stulzq/azure-openai-proxy/constant"
)

const (
AuthHeaderKey = "api-key"
)

var (
AzureOpenAIEndpoint = ""
AzureOpenAIEndpointParse *url.URL

AzureOpenAIAPIVer = ""

AzureOpenAIModelMapper = map[string]string{
"gpt-3.5-turbo": "gpt-35-turbo",
"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
}
fallbackModelMapper = regexp.MustCompile(`[.:]`)
)

func init() {
AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)

if AzureOpenAIAPIVer == "" {
AzureOpenAIAPIVer = "2023-03-15-preview"
}

var err error
AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
if err != nil {
log.Fatal("parse AzureOpenAIEndpoint error: ", err)
}

if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
for _, pair := range strings.Split(v, ",") {
info := strings.Split(pair, "=")
if len(info) != 2 {
log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
}

AzureOpenAIModelMapper[info[0]] = info[1]
}
}

log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
}
81 changes: 81 additions & 0 deletions azure/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package azure

import (
"bytes"
"fmt"
"github.com/bytedance/sonic"
"github.com/pkg/errors"
"github.com/stulzq/azure-openai-proxy/util"
"io"
"log"
"net/http"
"net/http/httputil"
"path"
"strings"

"github.com/gin-gonic/gin"
)

// Proxy Azure OpenAI
func Proxy(c *gin.Context) {
// improve performance some code from https://github.com/diemus/azure-openai-proxy/blob/main/pkg/azure/proxy.go
director := func(req *http.Request) {
if req.Body == nil {
util.SendError(c, errors.New("request body is empty"))
return
}
body, _ := io.ReadAll(req.Body)
req.Body = io.NopCloser(bytes.NewBuffer(body))

// get model from body
model, err := sonic.Get(body, "model")
if err != nil {
util.SendError(c, errors.Wrap(err, "get model error"))
return
}

deployment, err := model.String()
if err != nil {
util.SendError(c, errors.Wrap(err, "get deployment error"))
return
}
deployment = GetDeploymentByModel(deployment)

// get auth token from header
rawToken := req.Header.Get("Authorization")
token := strings.TrimPrefix(rawToken, "Bearer ")
req.Header.Set(AuthHeaderKey, token)
req.Header.Del("Authorization")

originURL := req.URL.String()
req.Host = AzureOpenAIEndpointParse.Host
req.URL.Scheme = AzureOpenAIEndpointParse.Scheme
req.URL.Host = AzureOpenAIEndpointParse.Host
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1))
req.URL.RawPath = req.URL.EscapedPath()

query := req.URL.Query()
query.Add("api-version", AzureOpenAIAPIVer)
req.URL.RawQuery = query.Encode()

log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
}

proxy := &httputil.ReverseProxy{Director: director}
proxy.ServeHTTP(c.Writer, c.Request)

// https://github.com/Chanzhaoyu/chatgpt-web/issues/831
if c.Writer.Header().Get("Content-Type") == "text/event-stream" {
if _, err := c.Writer.Write([]byte{'\n'}); err != nil {
log.Printf("rewrite response error: %v", err)
}
}
}

func GetDeploymentByModel(model string) string {
if v, ok := AzureOpenAIModelMapper[model]; ok {
return v
}

return fallbackModelMapper.ReplaceAllString(model, "")
}
2 changes: 1 addition & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

set -e

VERSION=v1.0.0
VERSION=v1.1.0

rm -rf bin

Expand Down
3 changes: 0 additions & 3 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/stulzq/azure-openai-proxy/openai"
"log"
"net/http"
"os"
Expand All @@ -13,8 +12,6 @@ import (
)

func main() {
openai.Init()

gin.SetMode(gin.ReleaseMode)
r := gin.Default()
registerRoute(r)
Expand Down
6 changes: 4 additions & 2 deletions cmd/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package main

import (
"github.com/gin-gonic/gin"
"github.com/stulzq/azure-openai-proxy/apis"
"github.com/stulzq/azure-openai-proxy/azure"
)

// registerRoute registers all routes
func registerRoute(r *gin.Engine) {
r.POST("/v1/chat/completions", apis.ChatCompletions)
// https://platform.openai.com/docs/api-reference
r.Any("*path", azure.Proxy)
}
6 changes: 3 additions & 3 deletions constant/env.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package constant

const (
ENV_AZURE_OPENAI_ENDPOINT = "AZURE_OPENAI_ENDPOINT"
ENV_AZURE_OPENAI_API_VER = "AZURE_OPENAI_API_VER"
ENV_AZURE_OPENAI_DEPLOY = "AZURE_OPENAI_DEPLOY"
ENV_AZURE_OPENAI_ENDPOINT = "AZURE_OPENAI_ENDPOINT"
ENV_AZURE_OPENAI_API_VER = "AZURE_OPENAI_API_VER"
ENV_AZURE_OPENAI_MODEL_MAPPER = "AZURE_OPENAI_MODEL_MAPPER"
)
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ require (
github.com/quic-go/qtls-go1-19 v0.2.0 // indirect
github.com/quic-go/qtls-go1-20 v0.1.0 // indirect
github.com/quic-go/quic-go v0.32.0 // indirect
github.com/tidwall/gjson v1.14.4 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
Expand Down
11 changes: 0 additions & 11 deletions openai/chat.go

This file was deleted.

20 changes: 0 additions & 20 deletions openai/init.go

This file was deleted.

17 changes: 0 additions & 17 deletions openai/vars.go

This file was deleted.

6 changes: 4 additions & 2 deletions apis/tools.go → util/response_err.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package apis
package util

import "github.com/gin-gonic/gin"
import (
"github.com/gin-gonic/gin"
)

func SendError(c *gin.Context, err error) {
c.JSON(500, ApiResponse{
Expand Down
2 changes: 1 addition & 1 deletion apis/types.go → util/types.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package apis
package util

type ApiResponse struct {
Error ErrorDescription `json:"error"`
Expand Down

0 comments on commit 1968022

Please sign in to comment.