Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ability to mimic /v1/models/ api #68

Merged
merged 4 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion azure/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package azure

import (
	"bytes"
	"encoding/json"
	"fmt"
	"github.com/stulzq/azure-openai-proxy/util"
	"io"
	"log"
	"net/http"
	"time"
Expand All @@ -21,6 +21,85 @@ func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc {
}
}

// DeploymentInfo models the JSON body returned by Azure's
// /openai/deployments listing endpoint, and is reused as the shape of
// the combined OpenAI-style /v1/models response (Object set to "list").
type DeploymentInfo struct {
	// Data holds one raw entry per deployment; kept as untyped maps so
	// Azure's fields pass through to the client unmodified.
	Data []map[string]interface{} `json:"data"`
	// Object is the OpenAI-style type tag of the payload (e.g. "list").
	Object string `json:"object"`
}

// ModelProxy mimics the OpenAI /v1/models endpoint: it queries every
// configured Azure deployment concurrently for its deployment list,
// merges all successful responses, and writes a single
// {"object":"list","data":[...]} JSON payload to the client.
// Individual deployment failures are logged and skipped rather than
// failing the whole request; only a marshalling error aborts it.
func ModelProxy(c *gin.Context) {
	// Buffered so every goroutine can send exactly one result (possibly
	// nil on failure) without blocking.
	results := make(chan []map[string]interface{}, len(ModelDeploymentConfig))

	// One shared client with a timeout: reuses connections across
	// deployments and ensures a slow/unreachable endpoint cannot hang
	// the handler indefinitely.
	client := &http.Client{Timeout: 10 * time.Second}

	// Fan out one request per configured deployment.
	for _, deployment := range ModelDeploymentConfig {
		go func(deployment DeploymentConfig) {
			req, err := http.NewRequest(http.MethodGet, deployment.Endpoint+"/openai/deployments?api-version=2022-12-01", nil)
			if err != nil {
				log.Printf("error creating request for deployment %s: %v", deployment.DeploymentName, err)
				results <- nil
				return
			}

			// Azure authenticates deployment-management calls via the
			// api-key header.
			req.Header.Set(AuthHeaderKey, deployment.ApiKey)

			resp, err := client.Do(req)
			if err != nil {
				log.Printf("error sending request for deployment %s: %v", deployment.DeploymentName, err)
				results <- nil
				return
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusOK {
				log.Printf("unexpected status code %d for deployment %s", resp.StatusCode, deployment.DeploymentName)
				results <- nil
				return
			}

			body, err := io.ReadAll(resp.Body)
			if err != nil {
				log.Printf("error reading response body for deployment %s: %v", deployment.DeploymentName, err)
				results <- nil
				return
			}

			var deploymentInfo DeploymentInfo
			if err := json.Unmarshal(body, &deploymentInfo); err != nil {
				log.Printf("error parsing response body for deployment %s: %v", deployment.DeploymentName, err)
				results <- nil
				return
			}
			results <- deploymentInfo.Data
		}(deployment)
	}

	// Fan in: collect exactly one result per deployment, dropping the
	// nil markers produced by failed requests.
	var allResults []map[string]interface{}
	for i := 0; i < len(ModelDeploymentConfig); i++ {
		if result := <-results; result != nil {
			allResults = append(allResults, result...)
		}
	}

	info := DeploymentInfo{Data: allResults, Object: "list"}
	combinedResults, err := json.Marshal(info)
	if err != nil {
		log.Printf("error marshalling results: %v", err)
		util.SendError(c, err)
		return
	}

	// Write the merged list back in OpenAI's response format.
	c.Header("Content-Type", "application/json")
	c.String(http.StatusOK, string(combinedResults))
}

// Proxy Azure OpenAI
func Proxy(c *gin.Context, requestConverter RequestConverter) {
if c.Request.Method == http.MethodOptions {
Expand Down
1 change: 1 addition & 0 deletions cmd/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func registerRoute(r *gin.Engine) {
})
apiBase := viper.GetString("api_base")
stripPrefixConverter := azure.NewStripPrefixConverter(apiBase)
r.GET(stripPrefixConverter.Prefix+"/models", azure.ModelProxy)
templateConverter := azure.NewTemplateConverter("/openai/deployments/{{.DeploymentName}}/embeddings")
apiBasedRouter := r.Group(apiBase)
{
Expand Down
Loading