diff --git a/azure/proxy.go b/azure/proxy.go index 2cbca31..9484372 100644 --- a/azure/proxy.go +++ b/azure/proxy.go @@ -2,6 +2,7 @@ package azure import ( "bytes" + "encoding/json" "fmt" "github.com/stulzq/azure-openai-proxy/util" "io" @@ -21,6 +22,85 @@ func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc { } } +type DeploymentInfo struct { + Data []map[string]interface{} `json:"data"` + Object string `json:"object"` +} + +func ModelProxy(c *gin.Context) { + // Create a channel to receive the results of each request + results := make(chan []map[string]interface{}, len(ModelDeploymentConfig)) + + // Send a request for each deployment in the map + for _, deployment := range ModelDeploymentConfig { + go func(deployment DeploymentConfig) { + // Create the request + req, err := http.NewRequest(http.MethodGet, deployment.Endpoint+"/openai/deployments?api-version=2022-12-01", nil) + if err != nil { + log.Printf("error parsing response body for deployment %s: %v", deployment.DeploymentName, err) + results <- nil + return + } + + // Set the auth header + req.Header.Set(AuthHeaderKey, deployment.ApiKey) + + // Send the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + log.Printf("error sending request for deployment %s: %v", deployment.DeploymentName, err) + results <- nil + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + log.Printf("unexpected status code %d for deployment %s", resp.StatusCode, deployment.DeploymentName) + results <- nil + return + } + + // Read the response body + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("error reading response body for deployment %s: %v", deployment.DeploymentName, err) + results <- nil + return + } + + // Parse the response body as JSON + var deplotmentInfo DeploymentInfo + err = json.Unmarshal(body, &deplotmentInfo) + if err != nil { + log.Printf("error parsing response body for deployment %s: %v", deployment.DeploymentName, err) + results <- nil + return + } + results <- deplotmentInfo.Data + }(deployment) + } + + // Wait for all requests to finish and collect the results + var allResults []map[string]interface{} + for i := 0; i < len(ModelDeploymentConfig); i++ { + result := <-results + if result != nil { + allResults = append(allResults, result...) + } + } + var info = DeploymentInfo{Data: allResults, Object: "list"} + combinedResults, err := json.Marshal(info) + if err != nil { + log.Printf("error marshalling results: %v", err) + util.SendError(c, err) + return + } + + // Set the response headers and body + c.Header("Content-Type", "application/json") + c.String(http.StatusOK, string(combinedResults)) +} + // Proxy Azure OpenAI func Proxy(c *gin.Context, requestConverter RequestConverter) { if c.Request.Method == http.MethodOptions { diff --git a/cmd/router.go b/cmd/router.go index 1b02c5a..a1d8add 100644 --- a/cmd/router.go +++ b/cmd/router.go @@ -17,6 +17,7 @@ func registerRoute(r *gin.Engine) { }) apiBase := viper.GetString("api_base") stripPrefixConverter := azure.NewStripPrefixConverter(apiBase) + r.GET(stripPrefixConverter.Prefix+"/models", azure.ModelProxy) templateConverter := azure.NewTemplateConverter("/openai/deployments/{{.DeploymentName}}/embeddings") apiBasedRouter := r.Group(apiBase) {