Skip to content

Commit

Permalink
support tts
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Nov 15, 2023
1 parent 16ad764 commit 63cd3f0
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 13 deletions.
4 changes: 3 additions & 1 deletion common/model-ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
Expand Down
9 changes: 9 additions & 0 deletions common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,12 @@ func String2Int(str string) int {
}
return num
}

func StringsContains(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
}
return false
}
56 changes: 47 additions & 9 deletions controller/relay-audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,48 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"strings"
)

var availableVoices = []string{
"alloy",
"echo",
"fable",
"onyx",
"nova",
"shimmer",
}

func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
audioModel := "whisper-1"

tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")

var audioRequest AudioRequest
err := common.UnmarshalBodyReusable(c, &audioRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}

// request validation
if audioRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}

if strings.HasPrefix(audioRequest.Model, "tts-1") {
if audioRequest.Voice == "" {
return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
}
if !common.StringsContains(availableVoices, audioRequest.Voice) {
return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
}
}

preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel)
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
Expand Down Expand Up @@ -58,8 +87,8 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioModel] != "" {
audioModel = modelMap[audioModel]
if modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
}
}

Expand Down Expand Up @@ -97,7 +126,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode

defer func(ctx context.Context) {
go func() {
quota := countTokenText(audioResponse.Text, audioModel)
var quota int
if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = countAudioToken(audioRequest.Input, audioRequest.Model)
} else {
quota = countAudioToken(audioResponse.Text, audioRequest.Model)
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
Expand All @@ -110,7 +144,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent, tokenId)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioRequest.Model, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
Expand All @@ -127,9 +161,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
if strings.HasPrefix(audioRequest.Model, "tts-1") {

} else {
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
}

resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
Expand Down
9 changes: 9 additions & 0 deletions controller/relay-utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"one-api/common"
"strconv"
"strings"
"unicode/utf8"
)

var stopFinishReason = "stop"
Expand Down Expand Up @@ -106,6 +107,14 @@ func countTokenInput(input any, model string) int {
return 0
}

func countAudioToken(text string, model string) int {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text)
} else {
return countTokenText(text, model)
}
}

func countTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
Expand Down
6 changes: 6 additions & 0 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
return input
}

type AudioRequest struct {
Model string `json:"model"`
Voice string `json:"voice"`
Input string `json:"input"`
}

type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Expand Down
9 changes: 6 additions & 3 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ func Distribute() func(c *gin.Context) {
if modelRequest.Model == "" {
modelRequest.Model = "midjourney"
}
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
Expand All @@ -70,7 +69,11 @@ func Distribute() func(c *gin.Context) {
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
Expand Down
1 change: 1 addition & 0 deletions router/relay-router.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay)
relayV1Router.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
Expand Down

0 comments on commit 63cd3f0

Please sign in to comment.