Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

支持讯飞星火微调平台微调模型 #1944

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions common/ctxkey/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ const (
AvailableModels = "available_models"
KeyRequestBody = "key_request_body"
SystemPrompt = "system_prompt"
LoraId = "lora_id"
)
55 changes: 44 additions & 11 deletions relay/adaptor/xunfei/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"

Expand All @@ -28,7 +29,7 @@ import (
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html

func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string, xunfeiPatchId string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
messages = append(messages, Message{
Expand All @@ -38,6 +39,9 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
}
xunfeiRequest := ChatRequest{}
xunfeiRequest.Header.AppId = xunfeiAppId
if xunfeiPatchId != "" {
xunfeiRequest.Header.PatchId = []string{xunfeiPatchId}
}
xunfeiRequest.Parameter.Chat.Domain = domain
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
xunfeiRequest.Parameter.Chat.TopK = request.N
Expand Down Expand Up @@ -93,7 +97,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
FinishReason: constant.StopFinishReason,
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Id: response.Header.Sid,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
Expand All @@ -102,7 +106,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
return &fullTextResponse
}

func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse, usage *model.Usage) *openai.ChatCompletionsStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
{
Expand All @@ -122,6 +126,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
Created: helper.GetTimestamp(),
Model: "SparkDesk",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
Usage: usage,
}
return &response
}
Expand Down Expand Up @@ -153,11 +158,12 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
}

func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId, meta.LoraId)
if err != nil {
return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil
}

common.SetEventStreamHeaders(c)
var usage model.Usage
c.Stream(func(w io.Writer) bool {
Expand All @@ -166,7 +172,23 @@ func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpe
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
if xunfeiResponse.Header.Code != 0 {
errMessage := fmt.Sprintf("Xunfei request failed with Sid: %s code: %d, msg: %s", xunfeiResponse.Header.Sid, xunfeiResponse.Header.Code, xunfeiResponse.Header.Message)
logger.SysError(errMessage)
mStr, err := json.Marshal(map[string]interface{}{
"error": map[string]interface{}{
"message": errMessage,
"code": xunfeiResponse.Header.Code,
},
})
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(mStr), Event: "error"})
return false // 停止流式响应
}
response := streamResponseXunfei2OpenAI(&xunfeiResponse, &usage)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
Expand All @@ -183,8 +205,8 @@ func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpe
}

func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId, meta.LoraId)
if err != nil {
return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil
}
Expand All @@ -205,6 +227,10 @@ func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIReq
case stop = <-stopChan:
}
}
if xunfeiResponse.Header.Code != 0 {
return openai.ErrorWrapper(errors.New("xunfei response error: sid: "+xunfeiResponse.Header.Sid), strconv.Itoa(xunfeiResponse.Header.Code), http.StatusInternalServerError), nil

}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
return openai.ErrorWrapper(errors.New("xunfei empty response detected"), "xunfei_empty_response_detected", http.StatusInternalServerError), nil
}
Expand All @@ -220,15 +246,15 @@ func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIReq
return nil, &usage
}

func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId, patchId string) (chan ChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
conn, resp, err := d.Dial(authUrl, nil)
if err != nil || resp.StatusCode != 101 {
return nil, nil, err
}
data := requestOpenAI2Xunfei(textRequest, appId, domain)
data := requestOpenAI2Xunfei(textRequest, appId, domain, patchId)
err = conn.WriteJSON(data)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -300,7 +326,7 @@ func apiVersion2domain(apiVersion string) string {
return "general" + apiVersion
}

func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string, modelName string) (string, string) {
var authUrl string
domain := apiVersion2domain(apiVersion)
switch apiVersion {
Expand All @@ -310,6 +336,13 @@ func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (strin
case "v3.5-32K":
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret)
break
case "maas":
domain = modelName
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://maas-api.cn-huabei-1.xf-yun.com/v1.1/chat"), apiKey, apiSecret)
case "xingchen":
domain = modelName
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://xingcheng-api.cn-huabei-1.xf-yun.com/v1.1/chat"), apiKey, apiSecret)

default:
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
}
Expand Down
3 changes: 2 additions & 1 deletion relay/adaptor/xunfei/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ type Functions struct {

type ChatRequest struct {
Header struct {
AppId string `json:"app_id"`
AppId string `json:"app_id"`
PatchId []string `json:"patch_id,omitempty"`
} `json:"header"`
Parameter struct {
Chat struct {
Expand Down
4 changes: 4 additions & 0 deletions relay/meta/relay_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type Meta struct {
RequestURLPath string
PromptTokens int // only for DoResponse
SystemPrompt string

//Lora_id
LoraId string
}

func GetByContext(c *gin.Context) *Meta {
Expand All @@ -48,6 +51,7 @@ func GetByContext(c *gin.Context) *Meta {
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
RequestURLPath: c.Request.URL.String(),
SystemPrompt: c.GetString(ctxkey.SystemPrompt),
LoraId: c.Request.Header.Get(ctxkey.LoraId),
}
cfg, ok := c.Get(ctxkey.Config)
if ok {
Expand Down