From 2199cf2304bbf0cfcbb582680fc87399f7147b3e Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sun, 12 Nov 2023 23:31:59 +0800
Subject: [PATCH 1/5] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=81=8A=E5=A4=A9?=
=?UTF-8?q?=E6=8C=89=E9=92=AEbug?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
web/src/components/SiderBar.js | 19 +++++++++----------
1 file changed, 9 insertions(+), 10 deletions(-)
diff --git a/web/src/components/SiderBar.js b/web/src/components/SiderBar.js
index fc22b8135..95298c9ac 100644
--- a/web/src/components/SiderBar.js
+++ b/web/src/components/SiderBar.js
@@ -15,7 +15,7 @@ import {
IconLayers,
IconSetting,
IconCreditCard,
- IconSemiLogo,
+ IconComment,
IconHome,
IconImage
} from '@douyinfe/semi-icons';
@@ -36,7 +36,13 @@ let headerButtons = [
icon: ,
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
},
-
+ {
+ text: '聊天',
+ itemKey: 'chat',
+ to: '/chat',
+ icon: ,
+ className: localStorage.getItem('chat_link')?'semi-navigation-item-normal':'tableHiddle',
+ },
{
text: '令牌',
itemKey: 'token',
@@ -89,14 +95,6 @@ let headerButtons = [
// }
];
-if (localStorage.getItem('chat_link')) {
- headerButtons.splice(1, 0, {
- name: '聊天',
- to: '/chat',
- icon: 'comments'
- });
-}
-
const HeaderBar = () => {
const [userState, userDispatch] = useContext(UserContext);
let navigate = useNavigate();
@@ -134,6 +132,7 @@ const HeaderBar = () => {
midjourney: "/midjourney",
setting: "/setting",
about: "/about",
+ chat: "/chat",
};
return (
Date: Sun, 12 Nov 2023 23:32:22 +0800
Subject: [PATCH 2/5] =?UTF-8?q?=E6=B7=BB=E5=8A=A0mj=E6=B8=A0=E9=81=93?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
common/model-ratio.go | 1 +
controller/model.go | 9 +++++++++
docker-compose.yml | 2 +-
web/src/constants/channel.constants.js | 1 +
4 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index bb2adc73d..f1cc07dab 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -14,6 +14,7 @@ import (
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
+ "midjourney": 50,
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
diff --git a/controller/model.go b/controller/model.go
index f9904330c..201d64311 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -54,6 +54,15 @@ func init() {
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
+ {
+ Id: "midjourney",
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: "Midjourney",
+ Permission: permission,
+ Root: "midjourney",
+ Parent: nil,
+ },
{
Id: "dall-e-2",
Object: "model",
diff --git a/docker-compose.yml b/docker-compose.yml
index 9b814a037..6c5350d1d 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -2,7 +2,7 @@ version: '3.4'
services:
one-api:
- image: justsong/one-api:latest
+ image: calciumion/neko-api:main
container_name: one-api
restart: always
command: --log-dir /app/logs
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index 764077455..6da8daff0 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -1,5 +1,6 @@
export const CHANNEL_OPTIONS = [
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
+ { key: 99, text: 'Midjourney-Proxy', value: 99, color: 'green' },
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
From b4bd9a19d9313364f6ca0bfc740ce63e9c124cf5 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sun, 12 Nov 2023 23:33:27 +0800
Subject: [PATCH 3/5] add docker-image-amd64.yml
---
.github/workflows/docker-image-amd64.yml | 54 +++++++-----------------
1 file changed, 16 insertions(+), 38 deletions(-)
diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-amd64.yml
index e3b8439ab..1ab220c67 100644
--- a/.github/workflows/docker-image-amd64.yml
+++ b/.github/workflows/docker-image-amd64.yml
@@ -1,54 +1,32 @@
-name: Publish Docker image (amd64)
+name: Docker Image CI
on:
push:
- tags:
- - '*'
- workflow_dispatch:
- inputs:
- name:
- description: 'reason'
- required: false
+ branches: [ "main" ]
+ pull_request:
+ branches: [ "main" ]
+
jobs:
- push_to_registries:
- name: Push Docker image to multiple registries
- runs-on: ubuntu-latest
- permissions:
- packages: write
- contents: read
- steps:
- - name: Check out the repo
- uses: actions/checkout@v3
- - name: Save version info
- run: |
- git describe --tags > VERSION
+ build:
+
+ runs-on: ubuntu-latest
- - name: Log in to Docker Hub
- uses: docker/login-action@v2
+ steps:
+ - uses: actions/checkout@v3
+ - uses: docker/login-action@v3.0.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
- - name: Log in to the Container registry
- uses: docker/login-action@v2
- with:
- registry: ghcr.io
- username: ${{ github.actor }}
- password: ${{ secrets.GITHUB_TOKEN }}
-
- name: Extract metadata (tags, labels) for Docker
id: meta
- uses: docker/metadata-action@v4
+ uses: docker/metadata-action@v3
with:
- images: |
- justsong/one-api
- ghcr.io/${{ github.repository }}
-
- - name: Build and push Docker images
- uses: docker/build-push-action@v3
+ images: calciumion/neko-api
+ - name: Build the Docker image
+ uses: docker/build-push-action@v5.0.0
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
- labels: ${{ steps.meta.outputs.labels }}
\ No newline at end of file
+ labels: ${{ steps.meta.outputs.labels }}
From 16ad764f9b7162b9032435b2e6163f7157281b52 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 15 Nov 2023 18:27:13 +0800
Subject: [PATCH 4/5] try to fix email
---
controller/relay-audio.go | 2 +-
controller/relay-image.go | 2 +-
controller/relay-mj.go | 2 +-
controller/relay-text.go | 4 ++--
model/token.go | 26 ++++++++++++++++----------
5 files changed, 21 insertions(+), 15 deletions(-)
diff --git a/controller/relay-audio.go b/controller/relay-audio.go
index fe91dbc63..13d9c9fdd 100644
--- a/controller/relay-audio.go
+++ b/controller/relay-audio.go
@@ -99,7 +99,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
go func() {
quota := countTokenText(audioResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
- err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota)
+ err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
diff --git a/controller/relay-image.go b/controller/relay-image.go
index 5cebcdb13..8c16ec1af 100644
--- a/controller/relay-image.go
+++ b/controller/relay-image.go
@@ -147,7 +147,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
var textResponse ImageResponse
defer func(ctx context.Context) {
if consumeQuota {
- err := model.PostConsumeTokenQuota(tokenId, userId, quota, 0)
+ err := model.PostConsumeTokenQuota(tokenId, userId, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
diff --git a/controller/relay-mj.go b/controller/relay-mj.go
index 948c57c01..89b0f0c84 100644
--- a/controller/relay-mj.go
+++ b/controller/relay-mj.go
@@ -359,7 +359,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
defer func(ctx context.Context) {
if consumeQuota {
- err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0)
+ err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 6f56be8ff..272965008 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -400,7 +400,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
- err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0)
+ err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
@@ -434,7 +434,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
quota = 0
}
quotaDelta := quota - preConsumedQuota
- err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota)
+ err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
diff --git a/model/token.go b/model/token.go
index 06b977502..5c4bc5550 100644
--- a/model/token.go
+++ b/model/token.go
@@ -5,6 +5,7 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
+ "strconv"
"strings"
)
@@ -194,22 +195,31 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
return 0, err
}
if userQuota < quota {
- return userQuota, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
+ return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
}
if !token.UnlimitedQuota {
err = DecreaseTokenQuota(tokenId, quota)
if err != nil {
- return userQuota, err
+ return 0, err
}
}
err = DecreaseUserQuota(token.UserId, quota)
- return userQuota, err
+ return userQuota - quota, err
}
-func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int) (err error) {
+func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 {
+ err = DecreaseUserQuota(token.UserId, quota)
+ } else {
+ err = IncreaseUserQuota(token.UserId, -quota)
+ }
+ if err != nil {
+ return err
+ }
+
+ if sendEmail {
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota {
@@ -229,16 +239,12 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
if err != nil {
common.SysError("failed to send email" + err.Error())
}
+ common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
}
}()
}
- err = DecreaseUserQuota(token.UserId, quota)
- } else {
- err = IncreaseUserQuota(token.UserId, -quota)
- }
- if err != nil {
- return err
}
+
if !token.UnlimitedQuota {
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)
From 63cd3f05f2a6e4e583a132206eeefb0dc15ec21e Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 15 Nov 2023 21:05:14 +0800
Subject: [PATCH 5/5] support tts
---
common/model-ratio.go | 4 ++-
common/utils.go | 9 +++++++
controller/relay-audio.go | 56 ++++++++++++++++++++++++++++++++-------
controller/relay-utils.go | 9 +++++++
controller/relay.go | 6 +++++
middleware/distributor.go | 9 ++++---
router/relay-router.go | 1 +
7 files changed, 81 insertions(+), 13 deletions(-)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index f1cc07dab..820f22833 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -37,7 +37,9 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
- "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+ "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+ "tts-1": 7.5, // 1k characters -> $0.015
+ "tts-1-hd": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
diff --git a/common/utils.go b/common/utils.go
index 21bec8f59..d65d42a67 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -207,3 +207,12 @@ func String2Int(str string) int {
}
return num
}
+
+func StringsContains(strs []string, str string) bool {
+ for _, s := range strs {
+ if s == str {
+ return true
+ }
+ }
+ return false
+}
diff --git a/controller/relay-audio.go b/controller/relay-audio.go
index 13d9c9fdd..e959e73df 100644
--- a/controller/relay-audio.go
+++ b/controller/relay-audio.go
@@ -11,10 +11,19 @@ import (
"net/http"
"one-api/common"
"one-api/model"
+ "strings"
)
+var availableVoices = []string{
+ "alloy",
+ "echo",
+ "fable",
+ "onyx",
+ "nova",
+ "shimmer",
+}
+
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
- audioModel := "whisper-1"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
@@ -22,8 +31,28 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
userId := c.GetInt("id")
group := c.GetString("group")
+ var audioRequest AudioRequest
+ err := common.UnmarshalBodyReusable(c, &audioRequest)
+ if err != nil {
+ return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+ }
+
+ // request validation
+ if audioRequest.Model == "" {
+ return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
+ }
+
+ if strings.HasPrefix(audioRequest.Model, "tts-1") {
+ if audioRequest.Voice == "" {
+ return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
+ }
+ if !common.StringsContains(availableVoices, audioRequest.Voice) {
+ return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
+ }
+ }
+
preConsumedTokens := common.PreConsumedQuota
- modelRatio := common.GetModelRatio(audioModel)
+ modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
@@ -58,8 +87,8 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
- if modelMap[audioModel] != "" {
- audioModel = modelMap[audioModel]
+ if modelMap[audioRequest.Model] != "" {
+ audioRequest.Model = modelMap[audioRequest.Model]
}
}
@@ -97,7 +126,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
defer func(ctx context.Context) {
go func() {
- quota := countTokenText(audioResponse.Text, audioModel)
+ var quota int
+ if strings.HasPrefix(audioRequest.Model, "tts-1") {
+ quota = countAudioToken(audioRequest.Input, audioRequest.Model)
+ } else {
+ quota = countAudioToken(audioResponse.Text, audioRequest.Model)
+ }
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
@@ -110,7 +144,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent, tokenId)
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioRequest.Model, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
@@ -127,9 +161,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
- err = json.Unmarshal(responseBody, &audioResponse)
- if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ if strings.HasPrefix(audioRequest.Model, "tts-1") {
+
+ } else {
+ err = json.Unmarshal(responseBody, &audioResponse)
+ if err != nil {
+ return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ }
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index d2f3d2fa3..40aa54726 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -10,6 +10,7 @@ import (
"one-api/common"
"strconv"
"strings"
+ "unicode/utf8"
)
var stopFinishReason = "stop"
@@ -106,6 +107,14 @@ func countTokenInput(input any, model string) int {
return 0
}
+func countAudioToken(text string, model string) int {
+ if strings.HasPrefix(model, "tts") {
+ return utf8.RuneCountInString(text)
+ } else {
+ return countTokenText(text, model)
+ }
+}
+
func countTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
diff --git a/controller/relay.go b/controller/relay.go
index c505c22e7..2ca2bc2da 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -70,6 +70,12 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
return input
}
+type AudioRequest struct {
+ Model string `json:"model"`
+ Voice string `json:"voice"`
+ Input string `json:"input"`
+}
+
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
diff --git a/middleware/distributor.go b/middleware/distributor.go
index c49a40d29..c9d8be8cf 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -46,9 +46,8 @@ func Distribute() func(c *gin.Context) {
if modelRequest.Model == "" {
modelRequest.Model = "midjourney"
}
- } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
- err = common.UnmarshalBodyReusable(c, &modelRequest)
}
+ err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
@@ -70,7 +69,11 @@ func Distribute() func(c *gin.Context) {
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
- modelRequest.Model = "whisper-1"
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+ modelRequest.Model = "tts-1"
+ } else {
+ modelRequest.Model = "whisper-1"
+ }
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
diff --git a/router/relay-router.go b/router/relay-router.go
index c97ea31d3..3916503e1 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay)
+ relayV1Router.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)