Skip to content

Commit

Permalink
intergrate gemini yet another time
Browse files Browse the repository at this point in the history
  • Loading branch information
t4ke0 committed Mar 14, 2024
1 parent b8cf65b commit 904f53f
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 150 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ require (
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.1
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
github.com/sashabaranov/go-openai v1.20.3
github.com/tidwall/gjson v1.14.4
github.com/xqdoo00o/OpenAIAuth v0.0.0-20240313154058-7c1d960f325b
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI=
github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
Expand Down
11 changes: 0 additions & 11 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,6 @@ func nightmare(c *gin.Context) {
return
}

// err := c.BindJSON(&original_request)
// if err != nil {
// c.JSON(400, gin.H{"error": gin.H{
// "message": "Request must be proper JSON",
// "type": "invalid_request_error",
// "param": nil,
// "code": err.Error(),
// }})
// return
// }

c.Request.Body = io.NopCloser(bytes.NewReader(buff.Bytes()))
if original_request.Model == "gemini-pro" {
api.ChatProxyHandler(c)
Expand Down
222 changes: 98 additions & 124 deletions internal/gemini/api/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/generative-ai-go/genai"
openai "github.com/sashabaranov/go-openai"
"github.com/zhu327/gemini-openai-proxy/pkg/adapter"
"google.golang.org/api/iterator"
"google.golang.org/api/option"

Expand All @@ -30,28 +31,114 @@ func IndexHandler(c *gin.Context) {
}

func ModelListHandler(c *gin.Context) {
model := openai.Model{
CreatedAt: 1686935002,
ID: openai.GPT3Dot5Turbo,
Object: "model",
OwnedBy: "openai",
}

c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": []any{model},
"data": []any{
openai.Model{
CreatedAt: 1686935002,
ID: openai.GPT3Dot5Turbo,
Object: "model",
OwnedBy: "openai",
},
openai.Model{
CreatedAt: 1686935002,
ID: openai.GPT4VisionPreview,
Object: "model",
OwnedBy: "openai",
},
},
})
}

func ModelRetrieveHandler(c *gin.Context) {
model := openai.Model{
model := c.Param("model")
c.JSON(http.StatusOK, openai.Model{
CreatedAt: 1686935002,
ID: openai.GPT3Dot5Turbo,
ID: model,
Object: "model",
OwnedBy: "openai",
})
}

func ChatProxyHandler(c *gin.Context) {
openaiAPIKey, err := getRandomAPIKey()
if err != nil {
log.Printf("Error getting API key: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve API key from gemini-api-key.json file"})
return
}
var req = &adapter.ChatCompletionRequest{}
// Bind the JSON data from the request to the struct
if err := c.ShouldBindJSON(req); err != nil {
c.JSON(http.StatusBadRequest, openai.APIError{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}

c.JSON(http.StatusOK, model)
ctx := c.Request.Context()
client, err := genai.NewClient(ctx, option.WithAPIKey(openaiAPIKey))
if err != nil {
log.Printf("new genai client error %v\n", err)
c.JSON(http.StatusBadRequest, openai.APIError{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}
defer client.Close()

var gemini adapter.GenaiModelAdapter
switch req.Model {
case openai.GPT4VisionPreview:
gemini = adapter.NewGeminiProVisionAdapter(client)
default:
gemini = adapter.NewGeminiProAdapter(client)
}

if !req.Stream {
resp, err := gemini.GenerateContent(ctx, req)
if err != nil {
log.Printf("genai generate content error %v\n", err)
c.JSON(http.StatusBadRequest, openai.APIError{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}

c.JSON(http.StatusOK, resp)
return
}

dataChan, err := gemini.GenerateStreamContent(ctx, req)
if err != nil {
log.Printf("genai generate content error %v\n", err)
c.JSON(http.StatusBadRequest, openai.APIError{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}

setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
if data, ok := <-dataChan; ok {
c.Render(-1, adapter.Event{Data: "data: " + data})
return true
}
c.Render(-1, adapter.Event{Data: "data: [DONE]"})
return false
})
}

func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

func getRandomAPIKey() (string, error) {
Expand Down Expand Up @@ -214,116 +301,3 @@ func VisionProxyHandler(c *gin.Context) {
return false
})
}

func ChatProxyHandler(c *gin.Context) {
openaiAPIKey, err := getRandomAPIKey()
if err != nil {
// Handle the error, for example, log it and return from the function
log.Printf("Error getting API key: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve API key from gemini-api-key.json file"})
return
}

println("use api key:" + openaiAPIKey)

if openaiAPIKey == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Api key not found!"})
return
}

var req openai.ChatCompletionRequest
// Bind the JSON data from the request to the struct
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

if len(req.Messages) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "request message must not be empty!"})
return
}

ctx := c.Request.Context()
client, err := genai.NewClient(ctx, option.WithAPIKey(openaiAPIKey))
if err != nil {
log.Printf("new genai client error %v\n", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
defer client.Close()

model := client.GenerativeModel(protocol.GeminiPro)
protocol.SetGenaiModelByOpenaiRequest(model, req)

cs := model.StartChat()
protocol.SetGenaiChatByOpenaiRequest(cs, req)

prompt := genai.Text(req.Messages[len(req.Messages)-1].Content)

if !req.Stream {
genaiResp, err := cs.SendMessage(ctx, prompt)
if err != nil {
log.Printf("genai send message error %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

openaiResp := protocol.GenaiResponseToOpenaiResponse(genaiResp)
c.JSON(http.StatusOK, openaiResp)
return
}

iter := cs.SendMessageStream(ctx, prompt)
dataChan := make(chan string)
go func() {
defer close(dataChan)

defer func() {
if r := recover(); r != nil {
log.Println("Recovered. Error:\n", r)
}
}()

respID := util.GetUUID()
created := time.Now().Unix()

for {
genaiResp, err := iter.Next()
if err == iterator.Done {
break
}

if err != nil {
log.Printf("genai get stream message error %v\n", err)
dataChan <- fmt.Sprintf(`{"error": "%s"}`, err.Error())
break
}

openaiResp := protocol.GenaiResponseToStreamCompletionResponse(genaiResp, respID, created)
resp, _ := json.Marshal(openaiResp)
dataChan <- string(resp)

if len(openaiResp.Choices) > 0 && openaiResp.Choices[0].FinishReason != nil {
break
}
}
}()

setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
if data, ok := <-dataChan; ok {
c.Render(-1, protocol.Event{Data: "data: " + data})
return true
}
c.Render(-1, protocol.Event{Data: "data: [DONE]"})
return false
})
}

func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
18 changes: 3 additions & 15 deletions internal/gemini/api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,17 @@ func Register(router *gin.Engine) {
// Configure CORS to allow all methods and all origins
config := cors.DefaultConfig()
config.AllowAllOrigins = true
config.AllowHeaders = []string{
"Accept",
"Authorization",
"Content-Type",
"Accept-Language",
"Content-Language",
"DPR",
"Downlink",
"Save-Data",
"Viewport-Width",
"Width",
"X-Requested-With",
}
config.AllowHeaders = []string{"*"}
config.AllowCredentials = true
config.OptionsResponseStatusCode = http.StatusOK
router.Use(cors.New(config))

// Define a route and its handler
router.GET("/", IndexHandler)
// openai model
router.GET("/v1/models", ModelListHandler)
router.GET("/v1/models/gpt-3.5-turbo", ModelRetrieveHandler)
router.GET("/v1/models/:model", ModelRetrieveHandler)

// openai chat
router.POST("/v1/chat/completions", ChatProxyHandler)
router.POST("/v1/chat/completions/vision", VisionProxyHandler)
}

0 comments on commit 904f53f

Please sign in to comment.