From 63cd3f05f2a6e4e583a132206eeefb0dc15ec21e Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 15 Nov 2023 21:05:14 +0800 Subject: [PATCH] support tts --- common/model-ratio.go | 4 ++- common/utils.go | 9 +++++++ controller/relay-audio.go | 56 ++++++++++++++++++++++++++++++++------- controller/relay-utils.go | 9 +++++++ controller/relay.go | 6 +++++ middleware/distributor.go | 9 ++++--- router/relay-router.go | 1 + 7 files changed, 81 insertions(+), 13 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index f1cc07dab..820f22833 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -37,7 +37,9 @@ var ModelRatio = map[string]float64{ "text-davinci-003": 10, "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, // 1k characters -> $0.015 + "tts-1-hd": 15, // 1k characters -> $0.03 "davinci": 10, "curie": 10, "babbage": 10, diff --git a/common/utils.go b/common/utils.go index 21bec8f59..d65d42a67 100644 --- a/common/utils.go +++ b/common/utils.go @@ -207,3 +207,12 @@ func String2Int(str string) int { } return num } + +func StringsContains(strs []string, str string) bool { + for _, s := range strs { + if s == str { + return true + } + } + return false +} diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 13d9c9fdd..e959e73df 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -11,10 +11,19 @@ import ( "net/http" "one-api/common" "one-api/model" + "strings" ) +var availableVoices = []string{ + "alloy", + "echo", + "fable", + "onyx", + "nova", + "shimmer", +} + func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - audioModel := "whisper-1" tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") @@ -22,8 +31,28 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode userId := c.GetInt("id") group := c.GetString("group") + var audioRequest AudioRequest + err := common.UnmarshalBodyReusable(c, &audioRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + + // request validation + if audioRequest.Model == "" { + return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) + } + + if strings.HasPrefix(audioRequest.Model, "tts-1") { + if audioRequest.Voice == "" { + return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest) + } + if !common.StringsContains(availableVoices, audioRequest.Voice) { + return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest) + } + } + preConsumedTokens := common.PreConsumedQuota - modelRatio := common.GetModelRatio(audioModel) + modelRatio := common.GetModelRatio(audioRequest.Model) groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) @@ -58,8 +87,8 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } - if modelMap[audioModel] != "" { - audioModel = modelMap[audioModel] + if modelMap[audioRequest.Model] != "" { + audioRequest.Model = modelMap[audioRequest.Model] } } @@ -97,7 +126,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode defer func(ctx context.Context) { go func() { - quota := countTokenText(audioResponse.Text, audioModel) + var quota int + if strings.HasPrefix(audioRequest.Model, "tts-1") { + quota = countAudioToken(audioRequest.Input, audioRequest.Model) + } else { + quota = countAudioToken(audioResponse.Text, audioRequest.Model) + } quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) if err != nil { @@ -110,7 +144,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent, tokenId) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioRequest.Model, tokenName, quota, logContent, tokenId) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) @@ -127,9 +161,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - err = json.Unmarshal(responseBody, &audioResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + if strings.HasPrefix(audioRequest.Model, "tts-1") { + + } else { + err = json.Unmarshal(responseBody, &audioResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } } resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index d2f3d2fa3..40aa54726 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -10,6 +10,7 @@ import ( "one-api/common" "strconv" "strings" + "unicode/utf8" ) var stopFinishReason = "stop" @@ -106,6 +107,14 @@ func countTokenInput(input any, model string) int { return 0 } +func countAudioToken(text string, model string) int { + if strings.HasPrefix(model, "tts") { + return utf8.RuneCountInString(text) + } else { + return countTokenText(text, model) + } +} + func countTokenText(text string, model string) int { tokenEncoder := getTokenEncoder(model) return getTokenNum(tokenEncoder, text) diff --git a/controller/relay.go b/controller/relay.go index c505c22e7..2ca2bc2da 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -70,6 +70,12 @@ func (r GeneralOpenAIRequest) ParseInput() []string { return input } +type AudioRequest struct { + Model string `json:"model"` + Voice string `json:"voice"` + Input string `json:"input"` +} + type ChatRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` diff --git a/middleware/distributor.go b/middleware/distributor.go index c49a40d29..c9d8be8cf 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -46,9 +46,8 @@ func Distribute() func(c *gin.Context) { if modelRequest.Model == "" { modelRequest.Model = "midjourney" } - } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - err = common.UnmarshalBodyReusable(c, &modelRequest) } + err = common.UnmarshalBodyReusable(c, &modelRequest) if err != nil { abortWithMessage(c, http.StatusBadRequest, "无效的请求") return @@ -70,7 +69,11 @@ func Distribute() func(c *gin.Context) { } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { if modelRequest.Model == "" { - modelRequest.Model = "whisper-1" + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + modelRequest.Model = "tts-1" + } else { + modelRequest.Model = "whisper-1" + } } } channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) diff --git a/router/relay-router.go b/router/relay-router.go index c97ea31d3..3916503e1 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay) relayV1Router.POST("/audio/translations", controller.Relay) + relayV1Router.POST("/audio/speech", controller.Relay) relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)