diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36798711a9..3f854868e8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,19 +1,17 @@ name: CI # This setup assumes that you run the unit tests with code coverage in the same -# workflow that will also print the coverage report as comment to the pull request. +# workflow that will also print the coverage report as comment to the pull request. # Therefore, you need to trigger this workflow when a pull request is (re)opened or # when new code is pushed to the branch of the pull request. In addition, you also -# need to trigger this workflow when new code is pushed to the main branch because +# need to trigger this workflow when new code is pushed to the main branch because # we need to upload the code coverage results as artifact for the main branch as # well since it will be the baseline code coverage. -# +# # We do not want to trigger the workflow for pushes to *any* branch because this # would trigger our jobs twice on pull requests (once from "push" event and once # from "pull_request->synchronize") on: - pull_request: - types: [opened, reopened, synchronize] push: branches: - 'main' @@ -31,7 +29,7 @@ jobs: with: go-version: ^1.22 - # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a + # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") # in the next step as well as the next job. 
- name: Test diff --git a/.gitignore b/.gitignore index ccc1c05518..aba380ff6c 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ logs data /web/node_modules cmd.md -.env \ No newline at end of file +.env +/one-api diff --git a/README.md b/README.md index 5f9947b0a7..853ec06794 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [together.ai](https://www.together.ai/) + [x] [novita.ai](https://www.novita.ai/) + [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud) + + [x] [xAI](https://x.ai/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 @@ -114,8 +115,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 21. 支持 Cloudflare Turnstile 用户校验。 22. 支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 - + 支持使用飞书进行授权登录。 - + [GitHub 开放授权](https://github.com/settings/applications/new)。 + + 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)([这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login))。 + + 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 @@ -174,6 +175,10 @@ sudo service nginx restart 初始账号用户名为 `root`,密码为 `123456`。 +### 通过宝塔面板进行一键部署 +1. 安装宝塔面板9.2.0及以上版本,前往 [宝塔面板](https://www.bt.cn/new/download.html?r=dk_oneapi) 官网,选择正式版的脚本下载安装; +2. 安装后登录宝塔面板,在左侧菜单栏中点击 `Docker`,首次进入会提示安装 `Docker` 服务,点击立即安装,按提示完成安装; +3. 安装完成后在应用商店中搜索 `One-API`,点击安装,配置域名等基本信息即可完成安装; ### 基于 Docker Compose 进行部署 @@ -217,7 +222,7 @@ docker-compose ps 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。 5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 -6. 
从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。 +6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟(Redis 集群或者哨兵模式的支持请参考环境变量说明)。 7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 环境变量的具体使用方法详见[此处](#环境变量)。 @@ -346,6 +351,11 @@ graph LR 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 + + 如果需要使用哨兵或者集群模式: + + 则需要把该环境变量设置为节点列表,例如:`localhost:49153,localhost:49154,localhost:49155`。 + + 除此之外还需要设置以下环境变量: + + `REDIS_PASSWORD`:Redis 集群或者哨兵模式下的密码设置。 + + `REDIS_MASTER_NAME`:Redis 哨兵模式下主节点的名称。 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 + 例子:`SESSION_SECRET=random_string` 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 @@ -399,6 +409,7 @@ graph LR 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 +29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage,默认不开启,可选值为 `true` 和 `false`。 ### 命令行参数 1. 
`--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/config/config.go b/common/config/config.go index 231dfde5f8..8235e3b167 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -35,6 +35,7 @@ var PasswordLoginEnabled = true var PasswordRegisterEnabled = true var EmailVerificationEnabled = false var GitHubOAuthEnabled = false +var OidcEnabled = false var WeChatAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true @@ -74,6 +75,13 @@ var GoogleClientSecret = "" var LarkClientId = "" var LarkClientSecret = "" +var OidcClientId = "" +var OidcClientSecret = "" +var OidcWellKnown = "" +var OidcAuthorizationEndpoint = "" +var OidcTokenEndpoint = "" +var OidcUserinfoEndpoint = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" @@ -156,3 +164,5 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) var RelayProxy = env.String("RELAY_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) + +var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false) diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 33f7355719..ffade3e53e 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -21,4 +21,5 @@ const ( BaseURL = "base_url" AvailableModels = "available_models" KeyRequestBody = "key_request_body" + SystemPrompt = "system_prompt" ) diff --git a/common/helper/helper.go b/common/helper/helper.go index e06dfb6e64..df7b0a5f9c 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -137,3 +137,23 @@ func String2Int(str string) int { } return num } + +func Float64PtrMax(p *float64, maxValue float64) *float64 { + if p == nil { + return nil + } + if *p > maxValue { + return &maxValue + } + return p +} + +func Float64PtrMin(p *float64, minValue float64) *float64 { + if p == nil { + return nil + } + if *p < minValue { + return &minValue + } + 
return p +} diff --git a/common/redis.go b/common/redis.go index bb09f5e47c..55d4931c92 100644 --- a/common/redis.go +++ b/common/redis.go @@ -2,44 +2,46 @@ package common import ( "context" - "github.com/go-redis/redis/v8" - "github.com/songquanpeng/one-api/common/logger" "os" + "strings" "time" + + "github.com/go-redis/redis/v8" + "github.com/songquanpeng/one-api/common/logger" ) -var RDB *redis.Client +var RDB redis.Cmdable var RedisEnabled = true // InitRedisClient This function is called after init() func InitRedisClient() (err error) { - //if os.Getenv("REDIS_CONN_STRING") == "" { - // RedisEnabled = false - // logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") - // return nil - //} - //if os.Getenv("SYNC_FREQUENCY") == "" { - // RedisEnabled = false - // logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") - // return nil - //} - //logger.SysLog("Redis is enabled") - //opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) - //if err != nil { - // logger.FatalLog("failed to parse Redis connection string: " + err.Error()) - //} - if os.Getenv("REDIS_HOST") == "" { + if os.Getenv("REDIS_CONN_STRING") == "" { RedisEnabled = false - logger.SysLog("REDIS_HOST not set, Redis is not enabled") + logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") return nil } - opt := &redis.Options{ - Addr: os.Getenv("REDIS_HOST"), - Password: os.Getenv("REDIS_PASSWORD"), - DB: 0, + if os.Getenv("SYNC_FREQUENCY") == "" { + RedisEnabled = false + logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") + return nil + } + redisConnString := os.Getenv("REDIS_CONN_STRING") + if os.Getenv("REDIS_MASTER_NAME") == "" { + logger.SysLog("Redis is enabled") + opt, err := redis.ParseURL(redisConnString) + if err != nil { + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) + } + RDB = redis.NewClient(opt) + } else { + // cluster mode + logger.SysLog("Redis cluster mode enabled") + RDB = 
redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: strings.Split(redisConnString, ","), + Password: os.Getenv("REDIS_PASSWORD"), + MasterName: os.Getenv("REDIS_MASTER_NAME"), + }) } - RDB = redis.NewClient(opt) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -50,6 +52,14 @@ func InitRedisClient() (err error) { return err } +func ParseRedisOption() *redis.Options { + opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) + if err != nil { + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) + } + return opt +} + func RedisSet(key string, value string, expiration time.Duration) error { ctx := context.Background() return RDB.Set(ctx, key, value, expiration).Err() diff --git a/common/render/render.go b/common/render/render.go index d2c37acdab..eb43b44135 100644 --- a/common/render/render.go +++ b/common/render/render.go @@ -3,9 +3,10 @@ package render import ( "encoding/json" "fmt" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" - "strings" ) func RawData(c *gin.Context, str string) { diff --git a/controller/auth/lark.go b/controller/auth/lark.go index eb06dde9f1..39088b3cc5 100644 --- a/controller/auth/lark.go +++ b/controller/auth/lark.go @@ -40,7 +40,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { if err != nil { return nil, err } - req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) + req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData)) if err != nil { return nil, err } diff --git a/controller/auth/oidc.go b/controller/auth/oidc.go new file mode 100644 index 0000000000..7b4ad4b9ee --- /dev/null +++ b/controller/auth/oidc.go @@ -0,0 +1,225 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + 
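The reworked `InitRedisClient` in `common/redis.go` above picks a plain client when `REDIS_MASTER_NAME` is unset and a universal (sentinel/cluster) client when it is set, splitting `REDIS_CONN_STRING` into a node list. The dispatch can be sketched without the go-redis dependency; `redisMode` and its return values are illustrative names, not part of one-api:

```go
package main

import (
	"fmt"
	"strings"
)

// redisMode mirrors the branch logic in InitRedisClient: a non-empty
// master name selects sentinel (failover) mode, several comma-separated
// nodes without a master name suggest cluster mode, and a single URL
// falls back to a plain client.
func redisMode(connString, masterName string) (mode string, addrs []string) {
	addrs = strings.Split(connString, ",")
	switch {
	case masterName != "":
		return "sentinel", addrs
	case len(addrs) > 1:
		return "cluster", addrs
	default:
		return "single", addrs
	}
}

func main() {
	mode, addrs := redisMode("localhost:49153,localhost:49154,localhost:49155", "mymaster")
	fmt.Println(mode, len(addrs)) // sentinel 3
}
```

With go-redis, all three cases can in fact be served by `redis.NewUniversalClient`, which is why the variable type was widened to `redis.Cmdable`.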
"github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "time" +) + +type OidcResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type OidcUser struct { + OpenID string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + Picture string `json:"picture"` +} + +func getOidcUserInfoByCode(code string) (*OidcUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.OidcClientId, + "client_secret": config.OidcClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + defer res.Body.Close() + var oidcResponse OidcResponse + err = json.NewDecoder(res.Body).Decode(&oidcResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 
服务器,请稍后重试!") + } + var oidcUser OidcUser + err = json.NewDecoder(res2.Body).Decode(&oidcUser) + if err != nil { + return nil, err + } + return &oidcUser, nil +} + +func OidcAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + OidcBind(c) + return + } + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + err := user.FillUserByOidcId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Email = oidcUser.Email + if oidcUser.PreferredUsername != "" { + user.Username = oidcUser.PreferredUsername + } else { + user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) + } + if oidcUser.Name != "" { + user.DisplayName = oidcUser.Name + } else { + user.DisplayName = "OIDC User" + } + err := user.Insert(0) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func OidcBind(c *gin.Context) { + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + 
"success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 OIDC 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.OidcId = oidcUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/billing.go b/controller/billing.go index 4e952a505d..63ba9b3bf6 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) { if config.DisplayTokenStatEnabled { tokenId := c.GetInt(ctxkey.TokenId) token, err = model.GetTokenById(tokenId) - expiredTime = token.ExpiredTime - remainQuota = token.RemainQuota - usedQuota = token.UsedQuota + if err == nil { + expiredTime = token.ExpiredTime + remainQuota = token.RemainQuota + usedQuota = token.UsedQuota + } } else { userId := c.GetInt(ctxkey.Id) remainQuota, err = model.GetUserQuota(userId) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 535927444e..e69cd9c256 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -4,16 +4,17 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "strconv" + "time" + "github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/config" 
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/channeltype" - "io" - "net/http" - "strconv" - "time" "github.com/gin-gonic/gin" ) @@ -81,6 +82,36 @@ type APGC2DGPTUsageResponse struct { TotalUsed float64 `json:"total_used"` } +type SiliconFlowUsageResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Status bool `json:"status"` + Data struct { + ID string `json:"id"` + Name string `json:"name"` + Image string `json:"image"` + Email string `json:"email"` + IsAdmin bool `json:"isAdmin"` + Balance string `json:"balance"` + Status string `json:"status"` + Introduction string `json:"introduction"` + Role string `json:"role"` + ChargeBalance string `json:"chargeBalance"` + TotalBalance string `json:"totalBalance"` + Category string `json:"category"` + } `json:"data"` +} + +type DeepSeekUsageResponse struct { + IsAvailable bool `json:"is_available"` + BalanceInfos []struct { + Currency string `json:"currency"` + TotalBalance string `json:"total_balance"` + GrantedBalance string `json:"granted_balance"` + ToppedUpBalance string `json:"topped_up_balance"` + } `json:"balance_infos"` +} + // GetAuthHeader get auth header func GetAuthHeader(token string) http.Header { h := http.Header{} @@ -203,6 +234,57 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { return response.TotalAvailable, nil } +func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) { + url := "https://api.siliconflow.cn/v1/user/info" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := SiliconFlowUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if response.Code != 20000 { + return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message) + } + balance, err := 
strconv.ParseFloat(response.Data.TotalBalance, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + return balance, nil +} + +func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) { + url := "https://api.deepseek.com/user/balance" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := DeepSeekUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + index := -1 + for i, balanceInfo := range response.BalanceInfos { + if balanceInfo.Currency == "CNY" { + index = i + break + } + } + if index == -1 { + return 0, errors.New("currency CNY not found") + } + balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + return balance, nil +} + func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := channeltype.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { @@ -227,6 +309,10 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { return updateChannelAPI2GPTBalance(channel) case channeltype.AIGC2D: return updateChannelAIGC2DBalance(channel) + case channeltype.SiliconFlow: + return updateChannelSiliconFlowBalance(channel) + case channeltype.DeepSeek: + return updateChannelDeepSeekBalance(channel) default: return 0, errors.New("尚未实现") } diff --git a/controller/channel-test.go b/controller/channel-test.go index 0d3837a49d..57b106716c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -76,9 +76,9 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques if len(modelNames) > 0 { modelName = modelNames[0] } - if modelMap != nil && modelMap[modelName] != "" { - modelName = modelMap[modelName] - } + } + if modelMap != nil && modelMap[modelName] != "" { + modelName = modelMap[modelName] } meta.OriginModelName, meta.ActualModelName 
= request.Model, modelName request.Model = modelName diff --git a/controller/misc.go b/controller/misc.go index 2928b8fb33..ae90087017 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) { "success": true, "message": "", "data": gin.H{ - "version": common.Version, - "start_time": common.StartTime, - "email_verification": config.EmailVerificationEnabled, - "github_oauth": config.GitHubOAuthEnabled, - "github_client_id": config.GitHubClientId, - "lark_client_id": config.LarkClientId, - "system_name": config.SystemName, - "logo": config.Logo, - "footer_html": config.Footer, - "wechat_qrcode": config.WeChatAccountQRCodeImageURL, - "wechat_login": config.WeChatAuthEnabled, - "server_address": config.ServerAddress, - "turnstile_check": config.TurnstileCheckEnabled, - "turnstile_site_key": config.TurnstileSiteKey, - "top_up_link": config.TopUpLink, - "chat_link": config.ChatLink, - "quota_per_unit": config.QuotaPerUnit, - "display_in_currency": config.DisplayInCurrencyEnabled, + "version": common.Version, + "start_time": common.StartTime, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, + "oidc": config.OidcEnabled, + "oidc_client_id": config.OidcClientId, + "oidc_well_known": config.OidcWellKnown, + "oidc_authorization_endpoint": config.OidcAuthorizationEndpoint, + "oidc_token_endpoint": 
config.OidcTokenEndpoint, + "oidc_userinfo_endpoint": config.OidcUserinfoEndpoint, }, }) return diff --git a/go.mod b/go.mod index 8f6a9cdefe..2bd205ca72 100644 --- a/go.mod +++ b/go.mod @@ -21,13 +21,20 @@ require ( github.com/gorilla/websocket v1.5.1 github.com/jinzhu/copier v0.4.0 github.com/joho/godotenv v1.5.1 + github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/pkoukk/tiktoken-go v0.1.7 + github.com/shopspring/decimal v1.4.0 + github.com/smartwalle/alipay/v3 v3.2.22 + github.com/smartwalle/xid v1.0.7 github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.9.0 - golang.org/x/crypto v0.24.0 + github.com/stripe/stripe-go/v81 v81.0.0 + golang.org/x/crypto v0.31.0 + golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 golang.org/x/image v0.18.0 + golang.org/x/sync v0.10.0 google.golang.org/api v0.187.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gorm.io/driver/mysql v1.5.6 @@ -77,7 +84,6 @@ require ( github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect @@ -88,14 +94,10 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/shopspring/decimal v1.4.0 // indirect - github.com/smartwalle/alipay/v3 v3.2.22 // indirect github.com/smartwalle/ncrypto v1.0.4 // indirect github.com/smartwalle/ngx v1.0.9 // indirect github.com/smartwalle/nsign v1.0.9 // indirect - github.com/smartwalle/xid v1.0.7 // indirect github.com/smarty/assertions v1.15.0 // indirect - github.com/stripe/stripe-go/v81 v81.0.0 // 
indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect go.opencensus.io v0.24.0 // indirect @@ -107,9 +109,8 @@ require ( golang.org/x/arch v0.8.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.21.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.5.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect diff --git a/go.sum b/go.sum index 942110baf6..3070da5ad1 100644 --- a/go.sum +++ b/go.sum @@ -239,9 +239,11 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo= +golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= 
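Stepping back to the balance updaters added in `controller/channel-billing.go` above: `updateChannelDeepSeekBalance` scans `balance_infos` for the CNY entry and parses its string-typed balance with `strconv.ParseFloat`. That selection step can be sketched on its own; the type and helper names are illustrative, not one-api's:

```go
package main

import (
	"errors"
	"fmt"
	"strconv"
)

// balanceInfo mirrors one entry of DeepSeek's balance_infos array.
type balanceInfo struct {
	Currency     string `json:"currency"`
	TotalBalance string `json:"total_balance"`
}

// cnyBalance picks the CNY entry and parses its balance, which the
// upstream API reports as a string rather than a number.
func cnyBalance(infos []balanceInfo) (float64, error) {
	for _, info := range infos {
		if info.Currency == "CNY" {
			return strconv.ParseFloat(info.TotalBalance, 64)
		}
	}
	return 0, errors.New("currency CNY not found")
}

func main() {
	infos := []balanceInfo{{"USD", "1.50"}, {"CNY", "42.00"}}
	b, err := cnyBalance(infos)
	fmt.Println(b, err) // 42 <nil>
}
```

The SiliconFlow updater follows the same pattern, except that it checks the response `code` field before parsing `data.totalBalance`.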
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -262,8 +264,8 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -272,14 +274,14 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod 
h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/middleware/distributor.go b/middleware/distributor.go index ffba0226f7..ec0284fb7f 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -67,6 +67,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelName, channel.Name) c.Set(ctxkey.ContentType, c.Request.Header.Get("Content-Type")) + if channel.SystemPrompt != nil && *channel.SystemPrompt != "" { + c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt) + } c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.OriginalModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/middleware/gzip.go b/middleware/gzip.go new file mode 100644 index 0000000000..4d4ce0c255 --- /dev/null +++ b/middleware/gzip.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "compress/gzip" + "github.com/gin-gonic/gin" + "io" + "net/http" +) + +func GzipDecodeMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if c.GetHeader("Content-Encoding") == "gzip" { + gzipReader, err := gzip.NewReader(c.Request.Body) + if err != nil { + c.AbortWithStatus(http.StatusBadRequest) 
+ return + } + defer gzipReader.Close() + + // Replace the request body with the decompressed data + c.Request.Body = io.NopCloser(gzipReader) + } + + // Continue processing the request + c.Next() + } +} diff --git a/model/channel.go b/model/channel.go index c4762df9e2..e4d01dd4b8 100644 --- a/model/channel.go +++ b/model/channel.go @@ -37,6 +37,7 @@ type Channel struct { ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` Config string `json:"config"` + SystemPrompt *string `json:"system_prompt" gorm:"type:text"` } type ChannelConfig struct { diff --git a/model/log.go b/model/log.go index 0c14d01000..5e998442e1 100644 --- a/model/log.go +++ b/model/log.go @@ -178,7 +178,8 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { // @deprecated func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { - tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") + ifnull := "ifnull" + tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull)) if username != "" { tx = tx.Where("username = ?", username) } diff --git a/model/option.go b/model/option.go index 017561838b..a5c4139945 100644 --- a/model/option.go +++ b/model/option.go @@ -29,6 +29,7 @@ func InitOptionMap() { config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) config.OptionMap["RegisterEnabled"] = 
strconv.FormatBool(config.RegisterEnabled) @@ -134,6 +135,8 @@ func updateOptionMap(key string, value string) (err error) { config.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": config.GitHubOAuthEnabled = boolValue + case "OidcEnabled": + config.OidcEnabled = boolValue case "WeChatAuthEnabled": config.WeChatAuthEnabled = boolValue case "TurnstileCheckEnabled": @@ -184,6 +187,18 @@ func updateOptionMap(key string, value string) (err error) { config.LarkClientId = value case "LarkClientSecret": config.LarkClientSecret = value + case "OidcClientId": + config.OidcClientId = value + case "OidcClientSecret": + config.OidcClientSecret = value + case "OidcWellKnown": + config.OidcWellKnown = value + case "OidcAuthorizationEndpoint": + config.OidcAuthorizationEndpoint = value + case "OidcTokenEndpoint": + config.OidcTokenEndpoint = value + case "OidcUserinfoEndpoint": + config.OidcUserinfoEndpoint = value case "Footer": config.Footer = value case "SystemName": diff --git a/model/token.go b/model/token.go index 69bdd19c1d..dfb12ceb53 100644 --- a/model/token.go +++ b/model/token.go @@ -32,7 +32,7 @@ type Token struct { RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota - Models *string `json:"models" gorm:"default:''"` // allowed models + Models *string `json:"models" gorm:"type:text"` // allowed models Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet } @@ -130,30 +130,40 @@ func GetTokenById(id int) (*Token, error) { return &token, err } -func (token *Token) Insert() error { +func (t *Token) Insert() error { var err error - err = DB.Create(token).Error + err = DB.Create(t).Error return err } // Update Make sure your token's fields is completed, because this will update non-zero values -func (token *Token) Update() error { +func (t *Token) Update() error { var err error - err = 
DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error + err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error return err } -func (token *Token) SelectUpdate() error { +func (t *Token) SelectUpdate() error { // This can update zero values - return DB.Model(token).Select("accessed_time", "status").Updates(token).Error + return DB.Model(t).Select("accessed_time", "status").Updates(t).Error } -func (token *Token) Delete() error { +func (t *Token) Delete() error { var err error - err = DB.Delete(token).Error + err = DB.Delete(t).Error return err } +func (t *Token) GetModels() string { + if t == nil { + return "" + } + if t.Models == nil { + return "" + } + return *t.Models +} + func DeleteTokenById(id int, userId int) (err error) { // Why we need userId here? In case user want to delete other's token. if id == 0 || userId == 0 { @@ -260,14 +270,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { token, err := GetTokenById(tokenId) + if err != nil { + return err + } if quota > 0 { err = DecreaseUserQuota(token.UserId, quota) } else { err = IncreaseUserQuota(token.UserId, -quota) } - if err != nil { - return err - } if !token.UnlimitedQuota { if quota > 0 { err = DecreaseTokenQuota(tokenId, quota) diff --git a/model/user.go b/model/user.go index 4f4950ae0a..ed1404ca71 100644 --- a/model/user.go +++ b/model/user.go @@ -41,6 +41,7 @@ type User struct { GoogleId string `json:"google_id" gorm:"column:google_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` LarkId string `json:"lark_id" gorm:"column:lark_id;index"` + OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! 
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int64 `json:"quota" gorm:"bigint;default:0"` @@ -260,6 +261,14 @@ func (user *User) FillUserByLarkId() error { return nil } +func (user *User) FillUserByOidcId() error { + if user.OidcId == "" { + return errors.New("oidc id 为空!") + } + DB.Where(User{OidcId: user.OidcId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -296,6 +305,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool { return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsOidcIdAlreadyTaken(oidcId string) bool { + return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } diff --git a/monitor/manage.go b/monitor/manage.go index fd67ea6342..d2f3c1f6d2 100644 --- a/monitor/manage.go +++ b/monitor/manage.go @@ -1,10 +1,11 @@ package monitor import ( - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/relay/model" "net/http" "strings" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/model" ) func ShouldDisableChannel(err *model.ErrorWithStatusCode, statusCode int) bool { @@ -21,31 +22,23 @@ func ShouldDisableChannel(err *model.ErrorWithStatusCode, statusCode int) bool { return true } switch err.Type { - case "insufficient_quota": - return true - // https://docs.anthropic.com/claude/reference/errors - case "authentication_error": - return true - case "permission_error": - return true - case "forbidden": + case "insufficient_quota", "authentication_error", "permission_error", "forbidden": return true } if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { return true } - if strings.HasPrefix(err.Message, "Your 
credit balance is too low") { // anthropic - return true - } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { - return true - } - //if strings.Contains(err.Message, "quota") { - // return true - //} - if strings.Contains(err.Message, "credit") { - return true - } - if strings.Contains(err.Message, "balance") { + + lowerMessage := strings.ToLower(err.Message) + if strings.Contains(lowerMessage, "your access was terminated") || + strings.Contains(lowerMessage, "violation of our policies") || + strings.Contains(lowerMessage, "your credit balance is too low") || + strings.Contains(lowerMessage, "organization has been disabled") || + strings.Contains(lowerMessage, "credit") || + strings.Contains(lowerMessage, "balance") || + strings.Contains(lowerMessage, "permission denied") || + strings.Contains(lowerMessage, "organization has been restricted") || // groq + strings.Contains(lowerMessage, "已欠费") { return true } return false diff --git a/relay/adaptor.go b/relay/adaptor.go index 711e63bdc6..03e8390319 100644 --- a/relay/adaptor.go +++ b/relay/adaptor.go @@ -16,6 +16,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/palm" "github.com/songquanpeng/one-api/relay/adaptor/proxy" + "github.com/songquanpeng/one-api/relay/adaptor/replicate" "github.com/songquanpeng/one-api/relay/adaptor/tencent" "github.com/songquanpeng/one-api/relay/adaptor/vertexai" "github.com/songquanpeng/one-api/relay/adaptor/xunfei" @@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { return &vertexai.Adaptor{} case apitype.Proxy: return &proxy.Adaptor{} + case apitype.Replicate: + return &replicate.Adaptor{} } return nil } diff --git a/relay/adaptor/ali/constants.go b/relay/adaptor/ali/constants.go index 3f24ce2e14..f3d9952000 100644 --- a/relay/adaptor/ali/constants.go +++ b/relay/adaptor/ali/constants.go @@ -1,7 +1,23 @@ package ali var ModelList = []string{ - "qwen-turbo", "qwen-plus", 
"qwen-max", "qwen-max-longcontext", - "text-embedding-v1", + "qwen-turbo", "qwen-turbo-latest", + "qwen-plus", "qwen-plus-latest", + "qwen-max", "qwen-max-latest", + "qwen-max-longcontext", + "qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest", + "qwen-vl-ocr", "qwen-vl-ocr-latest", + "qwen-audio-turbo", + "qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest", + "qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest", + "qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct", + "qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct", + "qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat", + "qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat", + "qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1", + "qwen2-audio-instruct", "qwen-audio-chat", + "qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct", + "qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct", + "text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1", "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", } diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go index 976e0d8d5c..74a402c4ef 100644 --- a/relay/adaptor/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -3,6 +3,7 @@ package ali import ( "bufio" "encoding/json" + 
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/render" "io" "net/http" @@ -35,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { enableSearch = true aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) } - if request.TopP >= 1 { - request.TopP = 0.9999 - } + request.TopP = helper.Float64PtrMax(request.TopP, 0.9999) return &ChatRequest{ Model: aliModel, Input: Input{ @@ -59,7 +58,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { return &EmbeddingRequest{ - Model: "text-embedding-v1", + Model: request.Model, Input: struct { Texts []string `json:"texts"` }{ @@ -102,8 +101,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat StatusCode: resp.StatusCode, }, nil } - + requestModel := c.GetString(ctxkey.RequestModel) fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) + fullTextResponse.Model = requestModel jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/adaptor/ali/model.go b/relay/adaptor/ali/model.go index 450b5f5292..a680c7e24b 100644 --- a/relay/adaptor/ali/model.go +++ b/relay/adaptor/ali/model.go @@ -16,13 +16,13 @@ type Input struct { } type Parameters struct { - TopP float64 `json:"top_p,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Seed uint64 `json:"seed,omitempty"` EnableSearch bool `json:"enable_search,omitempty"` IncrementalOutput bool `json:"incremental_output,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` ResultFormat string `json:"result_format,omitempty"` Tools []model.Tool `json:"tools,omitempty"` } diff --git 
a/relay/adaptor/anthropic/constants.go b/relay/adaptor/anthropic/constants.go index 54c749e828..8ea7c4d878 100644 --- a/relay/adaptor/anthropic/constants.go +++ b/relay/adaptor/anthropic/constants.go @@ -3,9 +3,10 @@ package anthropic var ModelList = []string{ "claude-instant-1.2", "claude-2.0", "claude-2.1", "claude-3-haiku-20240307", + "claude-3-5-haiku-20241022", "claude-3-sonnet-20240229", "claude-3-opus-20240229", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022", - "claude-3-5-haiku-20241022", + "claude-3-5-sonnet-latest", } diff --git a/relay/adaptor/anthropic/model.go b/relay/adaptor/anthropic/model.go index d335e5cffe..4d03a4e471 100644 --- a/relay/adaptor/anthropic/model.go +++ b/relay/adaptor/anthropic/model.go @@ -60,8 +60,8 @@ type Request struct { MaxTokens int `json:"max_tokens,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` diff --git a/relay/adaptor/aws/claude/main.go b/relay/adaptor/aws/claude/main.go index 2c254c5e4d..dc9704333c 100644 --- a/relay/adaptor/aws/claude/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -29,11 +29,12 @@ var AwsModelIDMap = map[string]string{ "claude-instant-1.2": "anthropic.claude-instant-v1", "claude-2.0": "anthropic.claude-v2", "claude-2.1": "anthropic.claude-v2:1", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", - "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", - "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", + "claude-3-5-sonnet-20240620": 
"anthropic.claude-3-5-sonnet-20240620-v1:0", "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-3-5-sonnet-latest": "anthropic.claude-3-5-sonnet-20241022-v2:0", "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0", } diff --git a/relay/adaptor/aws/claude/model.go b/relay/adaptor/aws/claude/model.go index 6d00b68865..106228877b 100644 --- a/relay/adaptor/aws/claude/model.go +++ b/relay/adaptor/aws/claude/model.go @@ -11,8 +11,8 @@ type Request struct { Messages []anthropic.Message `json:"messages"` System string `json:"system,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"` diff --git a/relay/adaptor/aws/llama3/model.go b/relay/adaptor/aws/llama3/model.go index 7b86c3b8ff..6cb64cdeac 100644 --- a/relay/adaptor/aws/llama3/model.go +++ b/relay/adaptor/aws/llama3/model.go @@ -4,10 +4,10 @@ package aws // // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html type Request struct { - Prompt string `json:"prompt"` - MaxGenLen int `json:"max_gen_len,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Prompt string `json:"prompt"` + MaxGenLen int `json:"max_gen_len,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` } // Response is the response from AWS Llama3 diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go index 26cf9ef47a..da2675ccdd 100644 --- a/relay/adaptor/baidu/main.go +++ b/relay/adaptor/baidu/main.go @@ -35,9 +35,9 @@ type Message struct { type ChatRequest struct { Messages []Message `json:"messages"` - Temperature float64 
`json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - PenaltyScore float64 `json:"penalty_score,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + PenaltyScore *float64 `json:"penalty_score,omitempty"` Stream bool `json:"stream,omitempty"` System string `json:"system,omitempty"` DisableSearch bool `json:"disable_search,omitempty"` diff --git a/relay/adaptor/cloudflare/model.go b/relay/adaptor/cloudflare/model.go index 0d3bafe098..8e382ba7ad 100644 --- a/relay/adaptor/cloudflare/model.go +++ b/relay/adaptor/cloudflare/model.go @@ -9,5 +9,5 @@ type Request struct { Prompt string `json:"prompt,omitempty"` Raw bool `json:"raw,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` } diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go index 45db437b6b..736c5a8d86 100644 --- a/relay/adaptor/cohere/main.go +++ b/relay/adaptor/cohere/main.go @@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { K: textRequest.TopK, Stream: textRequest.Stream, FrequencyPenalty: textRequest.FrequencyPenalty, - PresencePenalty: textRequest.FrequencyPenalty, + PresencePenalty: textRequest.PresencePenalty, Seed: int(textRequest.Seed), } if cohereRequest.Model == "" { diff --git a/relay/adaptor/cohere/model.go b/relay/adaptor/cohere/model.go index 64fa9c9403..3a8bc99dc7 100644 --- a/relay/adaptor/cohere/model.go +++ b/relay/adaptor/cohere/model.go @@ -10,15 +10,15 @@ type Request struct { PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" Connectors []Connector `json:"connectors,omitempty"` Documents []Document `json:"documents,omitempty"` - Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3 + Temperature *float64 `json:"temperature,omitempty"` // 默认值为0.3 MaxTokens int `json:"max_tokens,omitempty"` MaxInputTokens int 
`json:"max_input_tokens,omitempty"` K int `json:"k,omitempty"` // 默认值为0 - P float64 `json:"p,omitempty"` // 默认值为0.75 + P *float64 `json:"p,omitempty"` // 默认值为0.75 Seed int `json:"seed,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 - PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 Tools []Tool `json:"tools,omitempty"` ToolResults []ToolResult `json:"tool_results,omitempty"` } diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go index 12f48c715a..a86fde40b8 100644 --- a/relay/adaptor/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -24,7 +24,12 @@ func (a *Adaptor) Init(meta *meta.Meta) { } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { - version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) + defaultVersion := config.GeminiVersion + if meta.ActualModelName == "gemini-2.0-flash-exp" { + defaultVersion = "v1beta" + } + + version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion) action := "" switch meta.Mode { case relaymode.Embeddings: @@ -36,6 +41,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { if meta.IsStream { action = "streamGenerateContent?alt=sse" } + return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil } diff --git a/relay/adaptor/gemini/constants.go b/relay/adaptor/gemini/constants.go index b0f84dfc55..9d1cbc4acd 100644 --- a/relay/adaptor/gemini/constants.go +++ b/relay/adaptor/gemini/constants.go @@ -3,5 +3,9 @@ package gemini // https://ai.google.dev/models/gemini var ModelList = []string{ - "gemini-pro", "gemini-1.0-pro", "gemini-1.5-flash", "gemini-1.5-pro", "text-embedding-004", "aqa", + "gemini-pro", "gemini-1.0-pro", + 
"gemini-1.5-flash", "gemini-1.5-pro", + "text-embedding-004", "aqa", + "gemini-2.0-flash-exp", + "gemini-2.0-flash-thinking-exp", } diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 703c0f8036..b798942fa2 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -4,11 +4,12 @@ import ( "bufio" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" @@ -28,6 +29,11 @@ const ( VisionMaxImageNum = 16 ) +var mimeTypeMap = map[string]string{ + "json_object": "application/json", + "text": "text/plain", +} + // Setting safety to the lowest possible values since Gemini is already powerless enough func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { geminiRequest := ChatRequest{ @@ -49,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Threshold: config.GeminiSafetySetting, }, + { + Category: "HARM_CATEGORY_CIVIC_INTEGRITY", + Threshold: config.GeminiSafetySetting, + }, }, GenerationConfig: ChatGenerationConfig{ Temperature: textRequest.Temperature, @@ -56,6 +66,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { MaxOutputTokens: textRequest.MaxTokens, }, } + if textRequest.ResponseFormat != nil { + if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok { + geminiRequest.GenerationConfig.ResponseMimeType = mimeType + } + if textRequest.ResponseFormat.JsonSchema != nil { + geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema + geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"] + } + } if textRequest.Tools != nil { functions := make([]model.Function, 0, len(textRequest.Tools)) for _, tool := 
range textRequest.Tools { @@ -232,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { if candidate.Content.Parts[0].FunctionCall != nil { choice.Message.ToolCalls = getToolCalls(&candidate) } else { - choice.Message.Content = candidate.Content.Parts[0].Text + var builder strings.Builder + for i, part := range candidate.Content.Parts { + if i > 0 { + builder.WriteString("\n") + } + builder.WriteString(part.Text) + } + choice.Message.Content = builder.String() } } else { choice.Message.Content = "" diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go index f7179ea48e..720cb65d19 100644 --- a/relay/adaptor/gemini/model.go +++ b/relay/adaptor/gemini/model.go @@ -65,10 +65,12 @@ type ChatTools struct { } type ChatGenerationConfig struct { - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` + ResponseMimeType string `json:"responseMimeType,omitempty"` + ResponseSchema any `json:"responseSchema,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` } diff --git a/relay/adaptor/groq/constants.go b/relay/adaptor/groq/constants.go index 94b6c5fe6f..0864ebe75e 100644 --- a/relay/adaptor/groq/constants.go +++ b/relay/adaptor/groq/constants.go @@ -4,15 +4,23 @@ package groq var ModelList = []string{ "gemma-7b-it", - "mixtral-8x7b-32768", - "llama3-8b-8192", - "llama3-70b-8192", "gemma2-9b-it", - "llama-3.1-405b-reasoning", "llama-3.1-70b-versatile", "llama-3.1-8b-instant", + "llama-3.2-11b-text-preview", + 
"llama-3.2-11b-vision-preview", + "llama-3.2-1b-preview", + "llama-3.2-3b-preview", + "llama-3.2-90b-text-preview", + "llama-3.2-90b-vision-preview", + "llama-guard-3-8b", + "llama3-70b-8192", + "llama3-8b-8192", "llama3-groq-70b-8192-tool-use-preview", "llama3-groq-8b-8192-tool-use-preview", - "whisper-large-v3", + "llava-v1.5-7b-4096-preview", + "mixtral-8x7b-32768", "distil-whisper-large-v3-en", + "whisper-large-v3", + "whisper-large-v3-turbo", } diff --git a/relay/adaptor/ollama/model.go b/relay/adaptor/ollama/model.go index 7039984fcc..94f2ab7332 100644 --- a/relay/adaptor/ollama/model.go +++ b/relay/adaptor/ollama/model.go @@ -1,14 +1,14 @@ package ollama type Options struct { - Seed int `json:"seed,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - NumPredict int `json:"num_predict,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` + Seed int `json:"seed,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + NumPredict int `json:"num_predict,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` } type Message struct { diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index da820b9092..84dd646d16 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -77,6 +77,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } + if request.Stream { + // always return usage in stream mode + if request.StreamOptions == nil { + request.StreamOptions = 
&model.StreamOptions{} + } + request.StreamOptions.IncludeUsage = true + } return request, nil } diff --git a/relay/adaptor/openai/compatible.go b/relay/adaptor/openai/compatible.go index 0512f05ca7..15b4dcc032 100644 --- a/relay/adaptor/openai/compatible.go +++ b/relay/adaptor/openai/compatible.go @@ -11,9 +11,10 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/mistral" "github.com/songquanpeng/one-api/relay/adaptor/moonshot" "github.com/songquanpeng/one-api/relay/adaptor/novita" + "github.com/songquanpeng/one-api/relay/adaptor/siliconflow" "github.com/songquanpeng/one-api/relay/adaptor/stepfun" "github.com/songquanpeng/one-api/relay/adaptor/togetherai" - "github.com/songquanpeng/one-api/relay/adaptor/siliconflow" + "github.com/songquanpeng/one-api/relay/adaptor/xai" "github.com/songquanpeng/one-api/relay/channeltype" ) @@ -32,6 +33,7 @@ var CompatibleChannels = []int{ channeltype.TogetherAI, channeltype.Novita, channeltype.SiliconFlow, + channeltype.XAI, } func GetCompatibleChannelMeta(channelType int) (string, []string) { @@ -64,6 +66,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { return "novita", novita.ModelList case channeltype.SiliconFlow: return "siliconflow", siliconflow.ModelList + case channeltype.XAI: + return "xai", xai.ModelList default: return "openai", ModelList } diff --git a/relay/adaptor/openai/constants.go b/relay/adaptor/openai/constants.go index 156a50e7b0..8a643bc6ad 100644 --- a/relay/adaptor/openai/constants.go +++ b/relay/adaptor/openai/constants.go @@ -8,6 +8,9 @@ var ModelList = []string{ "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4o", "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-2024-11-20", + "chatgpt-4o-latest", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "gpt-4-vision-preview", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", @@ -18,4 +21,7 @@ var ModelList = []string{ "dall-e-2", 
"dall-e-3", "whisper-1", "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", + "o1", "o1-2024-12-17", + "o1-preview", "o1-preview-2024-09-12", + "o1-mini", "o1-mini-2024-09-12", } diff --git a/relay/adaptor/openai/helper.go b/relay/adaptor/openai/helper.go index 7d73303b8d..47c2a882b7 100644 --- a/relay/adaptor/openai/helper.go +++ b/relay/adaptor/openai/helper.go @@ -2,15 +2,16 @@ package openai import ( "fmt" + "strings" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/model" - "strings" ) -func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { +func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { usage := &model.Usage{} usage.PromptTokens = promptTokens - usage.CompletionTokens = CountTokenText(responseText, modeName) + usage.CompletionTokens = CountTokenText(responseText, modelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage } diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index a7e8af54a6..2f22ff4275 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -56,7 +56,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E continue // just ignore the error } if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil { - // but for empty choice, we should not pass it to client, this is for azure + // but for empty choice and no usage, we should not pass it to client, this is for azure continue // just ignore empty choice } render.StringData(c, data) diff --git a/relay/adaptor/openai/util.go b/relay/adaptor/openai/util.go index af6e265a80..b95fc86289 100644 --- a/relay/adaptor/openai/util.go +++ b/relay/adaptor/openai/util.go @@ -1,8 +1,16 @@ package openai -import "github.com/songquanpeng/one-api/relay/model" +import ( + "context" + "fmt" + + "github.com/songquanpeng/one-api/common/logger" + 
"github.com/songquanpeng/one-api/relay/model" +) func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { + logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err)) + Error := model.Error{ Message: err.Error(), Type: "Aihubmix_api_error", diff --git a/relay/adaptor/palm/model.go b/relay/adaptor/palm/model.go index f653022c3e..2bdd8f298b 100644 --- a/relay/adaptor/palm/model.go +++ b/relay/adaptor/palm/model.go @@ -19,11 +19,11 @@ type Prompt struct { } type ChatRequest struct { - Prompt Prompt `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` + Prompt Prompt `json:"prompt"` + Temperature *float64 `json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` } type Error struct { diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go new file mode 100644 index 0000000000..0013dedfb7 --- /dev/null +++ b/relay/adaptor/replicate/adaptor.go @@ -0,0 +1,131 @@ +package replicate + +import ( + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +type Adaptor struct { + meta *meta.Meta +} + +// ConvertImageRequest implements adaptor.Adaptor. 
+func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return DrawImageRequest{ + Input: ImageInput{ + Steps: 25, + Prompt: request.Prompt, + Guidance: 3, + Seed: int(time.Now().UnixNano()), + SafetyTolerance: 5, + NImages: 1, // replicate will always return 1 image + Width: 1440, + Height: 1440, + AspectRatio: "1:1", + }, + }, nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if !request.Stream { + // TODO: support non-stream mode + return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true") + } + + // Build the prompt from OpenAI messages + var promptBuilder strings.Builder + for _, message := range request.Messages { + switch msgCnt := message.Content.(type) { + case string: + promptBuilder.WriteString(message.Role) + promptBuilder.WriteString(": ") + promptBuilder.WriteString(msgCnt) + promptBuilder.WriteString("\n") + default: + } + } + + replicateRequest := ReplicateChatRequest{ + Input: ChatInput{ + Prompt: promptBuilder.String(), + MaxTokens: request.MaxTokens, + Temperature: 1.0, + TopP: 1.0, + PresencePenalty: 0.0, + FrequencyPenalty: 0.0, + }, + } + + // Map optional fields + if request.Temperature != nil { + replicateRequest.Input.Temperature = *request.Temperature + } + if request.TopP != nil { + replicateRequest.Input.TopP = *request.TopP + } + if request.PresencePenalty != nil { + replicateRequest.Input.PresencePenalty = *request.PresencePenalty + } + if request.FrequencyPenalty != nil { + replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty + } + if request.MaxTokens > 0 { + replicateRequest.Input.MaxTokens = request.MaxTokens + } else if request.MaxTokens == 0 { + replicateRequest.Input.MaxTokens = 500 + } + + return replicateRequest, nil +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return 
fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + logger.Info(c, "send request to replicate") + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + switch meta.Mode { + case relaymode.ImagesGenerations: + err, usage = ImageHandler(c, resp) + case relaymode.ChatCompletions: + err, usage = ChatHandler(c, resp) + default: + err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) + } + + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "replicate" +} diff --git a/relay/adaptor/replicate/chat.go b/relay/adaptor/replicate/chat.go new file mode 100644 index 0000000000..4051f85cec --- /dev/null +++ b/relay/adaptor/replicate/chat.go @@ -0,0 +1,191 @@ +package replicate + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +func ChatHandler(c *gin.Context, resp *http.Response) ( + srvErr *model.ErrorWithStatusCode, usage *model.Usage) { + if resp.StatusCode != http.StatusCreated { + payload, _ := io.ReadAll(resp.Body) + return openai.ErrorWrapper( + errors.Errorf("bad_status_code [%d]%s", 
resp.StatusCode, string(payload)), + "bad_status_code", http.StatusInternalServerError), + nil + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + + respData := new(ChatResponse) + if err = json.Unmarshal(respBody, respData); err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + for { + err = func() error { + // get task + taskReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, respData.URLs.Get, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + taskResp, err := http.DefaultClient.Do(taskReq) + if err != nil { + return errors.Wrap(err, "get task") + } + defer taskResp.Body.Close() + + if taskResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(taskResp.Body) + return errors.Errorf("bad status code [%d]%s", + taskResp.StatusCode, string(payload)) + } + + taskBody, err := io.ReadAll(taskResp.Body) + if err != nil { + return errors.Wrap(err, "read task response") + } + + taskData := new(ChatResponse) + if err = json.Unmarshal(taskBody, taskData); err != nil { + return errors.Wrap(err, "decode task response") + } + + switch taskData.Status { + case "succeeded": + case "failed", "canceled": + return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) + default: + time.Sleep(time.Second * 3) + return errNextLoop + } + + if taskData.URLs.Stream == "" { + return errors.New("stream url is empty") + } + + // request stream url + responseText, err := chatStreamHandler(c, taskData.URLs.Stream) + if err != nil { + return errors.Wrap(err, "chat stream handler") + } + + ctxMeta := meta.GetByContext(c) + usage = openai.ResponseText2Usage(responseText, + ctxMeta.ActualModelName, ctxMeta.PromptTokens) + return nil + }() + if err != nil { + if 
errors.Is(err, errNextLoop) { + continue + } + + return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil + } + + break + } + + return nil, usage +} + +const ( + eventPrefix = "event: " + dataPrefix = "data: " + done = "[DONE]" +) + +func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) { + // request stream endpoint + streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil) + if err != nil { + return "", errors.Wrap(err, "new request to stream") + } + + streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + streamReq.Header.Set("Accept", "text/event-stream") + streamReq.Header.Set("Cache-Control", "no-store") + + resp, err := http.DefaultClient.Do(streamReq) + if err != nil { + return "", errors.Wrap(err, "do request to stream") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(resp.Body) + return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + common.SetEventStreamHeaders(c) + doneRendered := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + // Handle comments starting with ':' + if strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE fields + if strings.HasPrefix(line, eventPrefix) { + event := strings.TrimSpace(line[len(eventPrefix):]) + var data string + // Read the following lines to get data and id + for scanner.Scan() { + nextLine := scanner.Text() + if nextLine == "" { + break + } + if strings.HasPrefix(nextLine, dataPrefix) { + data = nextLine[len(dataPrefix):] + } else if strings.HasPrefix(nextLine, "id:") { + // id = strings.TrimSpace(nextLine[len("id:"):]) + } + } + + if event == "output" { + render.StringData(c, data) + responseText += data + } else if event == "done" { + render.Done(c) + 
doneRendered = true + break + } + } + } + + if err := scanner.Err(); err != nil { + return "", errors.Wrap(err, "scan stream") + } + + if !doneRendered { + render.Done(c) + } + + return responseText, nil +} diff --git a/relay/adaptor/replicate/constant.go b/relay/adaptor/replicate/constant.go new file mode 100644 index 0000000000..989142c9e1 --- /dev/null +++ b/relay/adaptor/replicate/constant.go @@ -0,0 +1,58 @@ +package replicate + +// ModelList is a list of models that can be used with Replicate. +// +// https://replicate.com/pricing +var ModelList = []string{ + // ------------------------------------- + // image model + // ------------------------------------- + "black-forest-labs/flux-1.1-pro", + "black-forest-labs/flux-1.1-pro-ultra", + "black-forest-labs/flux-canny-dev", + "black-forest-labs/flux-canny-pro", + "black-forest-labs/flux-depth-dev", + "black-forest-labs/flux-depth-pro", + "black-forest-labs/flux-dev", + "black-forest-labs/flux-dev-lora", + "black-forest-labs/flux-fill-dev", + "black-forest-labs/flux-fill-pro", + "black-forest-labs/flux-pro", + "black-forest-labs/flux-redux-dev", + "black-forest-labs/flux-redux-schnell", + "black-forest-labs/flux-schnell", + "black-forest-labs/flux-schnell-lora", + "ideogram-ai/ideogram-v2", + "ideogram-ai/ideogram-v2-turbo", + "recraft-ai/recraft-v3", + "recraft-ai/recraft-v3-svg", + "stability-ai/stable-diffusion-3", + "stability-ai/stable-diffusion-3.5-large", + "stability-ai/stable-diffusion-3.5-large-turbo", + "stability-ai/stable-diffusion-3.5-medium", + // ------------------------------------- + // language model + // ------------------------------------- + "ibm-granite/granite-20b-code-instruct-8k", + "ibm-granite/granite-3.0-2b-instruct", + "ibm-granite/granite-3.0-8b-instruct", + "ibm-granite/granite-8b-code-instruct-128k", + "meta/llama-2-13b", + "meta/llama-2-13b-chat", + "meta/llama-2-70b", + "meta/llama-2-70b-chat", + "meta/llama-2-7b", + "meta/llama-2-7b-chat", + 
"meta/meta-llama-3.1-405b-instruct", + "meta/meta-llama-3-70b", + "meta/meta-llama-3-70b-instruct", + "meta/meta-llama-3-8b", + "meta/meta-llama-3-8b-instruct", + "mistralai/mistral-7b-instruct-v0.2", + "mistralai/mistral-7b-v0.1", + "mistralai/mixtral-8x7b-instruct-v0.1", + // ------------------------------------- + // video model + // ------------------------------------- + // "minimax/video-01", // TODO: implement the adaptor +} diff --git a/relay/adaptor/replicate/image.go b/relay/adaptor/replicate/image.go new file mode 100644 index 0000000000..3687249a1f --- /dev/null +++ b/relay/adaptor/replicate/image.go @@ -0,0 +1,222 @@ +package replicate + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "image" + "image/png" + "io" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "golang.org/x/image/webp" + "golang.org/x/sync/errgroup" +) + +// ImagesEditsHandler just copy response body to client +// +// https://replicate.com/black-forest-labs/flux-fill-pro +// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { +// c.Writer.WriteHeader(resp.StatusCode) +// for k, v := range resp.Header { +// c.Writer.Header().Set(k, v[0]) +// } + +// if _, err := io.Copy(c.Writer, resp.Body); err != nil { +// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil +// } +// defer resp.Body.Close() + +// return nil, nil +// } + +var errNextLoop = errors.New("next_loop") + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + if resp.StatusCode != http.StatusCreated { + payload, _ := io.ReadAll(resp.Body) + return openai.ErrorWrapper( + errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, 
string(payload)), + "bad_status_code", http.StatusInternalServerError), + nil + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + + respData := new(ImageResponse) + if err = json.Unmarshal(respBody, respData); err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + for { + err = func() error { + // get task + taskReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, respData.URLs.Get, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + taskResp, err := http.DefaultClient.Do(taskReq) + if err != nil { + return errors.Wrap(err, "get task") + } + defer taskResp.Body.Close() + + if taskResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(taskResp.Body) + return errors.Errorf("bad status code [%d]%s", + taskResp.StatusCode, string(payload)) + } + + taskBody, err := io.ReadAll(taskResp.Body) + if err != nil { + return errors.Wrap(err, "read task response") + } + + taskData := new(ImageResponse) + if err = json.Unmarshal(taskBody, taskData); err != nil { + return errors.Wrap(err, "decode task response") + } + + switch taskData.Status { + case "succeeded": + case "failed", "canceled": + return errors.Errorf("task failed: %s", taskData.Status) + default: + time.Sleep(time.Second * 3) + return errNextLoop + } + + output, err := taskData.GetOutput() + if err != nil { + return errors.Wrap(err, "get output") + } + if len(output) == 0 { + return errors.New("response output is empty") + } + + var mu sync.Mutex + var pool errgroup.Group + respBody := &openai.ImageResponse{ + Created: taskData.CompletedAt.Unix(), + Data: []openai.ImageData{}, + } + + for _, imgOut := range output { + imgOut := imgOut + pool.Go(func() error { + // download image + downloadReq, err := 
http.NewRequestWithContext(c.Request.Context(),
+					http.MethodGet, imgOut, nil)
+				if err != nil {
+					return errors.Wrap(err, "new request")
+				}
+
+				imgResp, err := http.DefaultClient.Do(downloadReq)
+				if err != nil {
+					return errors.Wrap(err, "download image")
+				}
+				defer imgResp.Body.Close()
+
+				if imgResp.StatusCode != http.StatusOK {
+					payload, _ := io.ReadAll(imgResp.Body)
+					return errors.Errorf("bad status code [%d]%s",
+						imgResp.StatusCode, string(payload))
+				}
+
+				imgData, err := io.ReadAll(imgResp.Body)
+				if err != nil {
+					return errors.Wrap(err, "read image")
+				}
+
+				imgData, err = ConvertImageToPNG(imgData)
+				if err != nil {
+					return errors.Wrap(err, "convert image")
+				}
+
+				mu.Lock()
+				respBody.Data = append(respBody.Data, openai.ImageData{
+					B64Json: fmt.Sprintf("data:image/png;base64,%s",
+						base64.StdEncoding.EncodeToString(imgData)),
+				})
+				mu.Unlock()
+
+				return nil
+			})
+		}
+
+		if err := pool.Wait(); err != nil {
+			if len(respBody.Data) == 0 {
+				return errors.WithStack(err)
+			}
+
+			logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
+		}
+
+		c.JSON(http.StatusOK, respBody)
+		return nil
+	}()
+	if err != nil {
+		if errors.Is(err, errNextLoop) {
+			continue
+		}
+
+		return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
+	}
+
+	break
+}
+
+return nil, nil
+}
+
+// ConvertImageToPNG converts a WebP or JPEG image to PNG format; input that is
+// already PNG is returned unchanged
+func ConvertImageToPNG(webpData []byte) ([]byte, error) {
+	// bypass if it's already a PNG image
+	if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
+		return webpData, nil
+	}
+
+	// if it is a JPEG image, re-encode it as PNG
+	if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
+		img, _, err := image.Decode(bytes.NewReader(webpData))
+		if err != nil {
+			return nil, errors.Wrap(err, "decode jpeg")
+		}
+
+		var pngBuffer bytes.Buffer
+		if err := png.Encode(&pngBuffer, img); err != nil {
+			return nil, errors.Wrap(err, "encode png")
+		}
+
+		return pngBuffer.Bytes(), nil
+	}
+
+	// Decode the
WebP image
+	img, err := webp.Decode(bytes.NewReader(webpData))
+	if err != nil {
+		return nil, errors.Wrap(err, "decode webp")
+	}
+
+	// Encode the image as PNG
+	var pngBuffer bytes.Buffer
+	if err := png.Encode(&pngBuffer, img); err != nil {
+		return nil, errors.Wrap(err, "encode png")
+	}
+
+	return pngBuffer.Bytes(), nil
+}
diff --git a/relay/adaptor/replicate/model.go b/relay/adaptor/replicate/model.go
new file mode 100644
index 0000000000..dba277eb5f
--- /dev/null
+++ b/relay/adaptor/replicate/model.go
@@ -0,0 +1,159 @@
+package replicate
+
+import (
+	"time"
+
+	"github.com/pkg/errors"
+)
+
+// DrawImageRequest is the request to draw an image with flux-pro
+//
+// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
+type DrawImageRequest struct {
+	Input ImageInput `json:"input"`
+}
+
+// ImageInput is the input of DrawImageRequest
+//
+// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
+type ImageInput struct {
+	Steps           int    `json:"steps" binding:"required,min=1"`
+	Prompt          string `json:"prompt" binding:"required,min=5"`
+	ImagePrompt     string `json:"image_prompt"`
+	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"`
+	Interval        int    `json:"interval" binding:"required,min=1,max=4"`
+	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
+	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
+	Seed            int    `json:"seed"`
+	NImages         int    `json:"n_images" binding:"required,min=1,max=8"`
+	Width           int    `json:"width" binding:"required,min=256,max=1440"`
+	Height          int    `json:"height" binding:"required,min=256,max=1440"`
+}
+
+// InpaintingImageByFlusReplicateRequest is the request to inpaint an image with flux-fill-pro
+//
+// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
+type InpaintingImageByFlusReplicateRequest struct {
+	Input FluxInpaintingInput `json:"input"`
+}
+
+// FluxInpaintingInput is the input of InpaintingImageByFlusReplicateRequest
+//
+// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
+type FluxInpaintingInput struct {
+	Mask             string `json:"mask" binding:"required"`
+	Image            string `json:"image" binding:"required"`
+	Seed             int    `json:"seed"`
+	Steps            int    `json:"steps" binding:"required,min=1"`
+	Prompt           string `json:"prompt" binding:"required,min=5"`
+	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"`
+	OutputFormat     string `json:"output_format"`
+	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
+	PromptUpsampling bool   `json:"prompt_upsampling"`
+}
+
+// ImageResponse is the response of DrawImageRequest
+//
+// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
+type ImageResponse struct {
+	CompletedAt time.Time        `json:"completed_at"`
+	CreatedAt   time.Time        `json:"created_at"`
+	DataRemoved bool             `json:"data_removed"`
+	Error       string           `json:"error"`
+	ID          string           `json:"id"`
+	Input       DrawImageRequest `json:"input"`
+	Logs        string           `json:"logs"`
+	Metrics     FluxMetrics      `json:"metrics"`
+	// Output could be `string` or `[]string`
+	Output    any       `json:"output"`
+	StartedAt time.Time `json:"started_at"`
+	Status    string    `json:"status"`
+	URLs      FluxURLs  `json:"urls"`
+	Version   string    `json:"version"`
+}
+
+func (r *ImageResponse) GetOutput() ([]string, error) {
+	switch v := r.Output.(type) {
+	case string:
+		return []string{v}, nil
+	case []string:
+		return v, nil
+	case nil:
+		return nil, nil
+	case []interface{}:
+		// convert []interface{} to []string
+		ret := make([]string, len(v))
+		for idx, vv := range v {
+			if vvv, ok := vv.(string); ok {
+				ret[idx] = vvv
+			} else {
+				return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
+			}
+		}
+
+		return ret, nil
+	default:
+		return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
+	}
+}
+
+// FluxMetrics holds the metrics of ImageResponse
+type FluxMetrics struct {
+	ImageCount  int     `json:"image_count"`
+	PredictTime float64 `json:"predict_time"`
+	
TotalTime float64 `json:"total_time"`
+}
+
+// FluxURLs is the task URLs of ImageResponse
+type FluxURLs struct {
+	Get    string `json:"get"`
+	Cancel string `json:"cancel"`
+}
+
+type ReplicateChatRequest struct {
+	Input ChatInput `json:"input" form:"input" binding:"required"`
+}
+
+// ChatInput is the input of ReplicateChatRequest
+//
+// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
+type ChatInput struct {
+	TopK             int     `json:"top_k"`
+	TopP             float64 `json:"top_p"`
+	Prompt           string  `json:"prompt"`
+	MaxTokens        int     `json:"max_tokens"`
+	MinTokens        int     `json:"min_tokens"`
+	Temperature      float64 `json:"temperature"`
+	SystemPrompt     string  `json:"system_prompt"`
+	StopSequences    string  `json:"stop_sequences"`
+	PromptTemplate   string  `json:"prompt_template"`
+	PresencePenalty  float64 `json:"presence_penalty"`
+	FrequencyPenalty float64 `json:"frequency_penalty"`
+}
+
+// ChatResponse is the response of ReplicateChatRequest
+//
+// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
+type ChatResponse struct {
+	CompletedAt time.Time   `json:"completed_at"`
+	CreatedAt   time.Time   `json:"created_at"`
+	DataRemoved bool        `json:"data_removed"`
+	Error       string      `json:"error"`
+	ID          string      `json:"id"`
+	Input       ChatInput   `json:"input"`
+	Logs        string      `json:"logs"`
+	Metrics     FluxMetrics `json:"metrics"`
+	// Output is the list of generated text chunks
+	Output    []string        `json:"output"`
+	StartedAt time.Time       `json:"started_at"`
+	Status    string          `json:"status"`
+	URLs      ChatResponseUrl `json:"urls"`
+	Version   string          `json:"version"`
+}
+
+// ChatResponseUrl is the task URLs of ChatResponse
+type ChatResponseUrl struct {
+	Stream string `json:"stream"`
+	Get    string `json:"get"`
+	Cancel string `json:"cancel"`
+}
diff --git a/relay/adaptor/stepfun/constants.go b/relay/adaptor/stepfun/constants.go
index a82e562b2b..6a2346cac5 100644
--- a/relay/adaptor/stepfun/constants.go
+++ b/relay/adaptor/stepfun/constants.go
@@ -1,7 +1,13 @@
 package stepfun
 
 var ModelList = 
[]string{ + "step-1-8k", "step-1-32k", + "step-1-128k", + "step-1-256k", + "step-1-flash", + "step-2-16k", + "step-1v-8k", "step-1v-32k", - "step-1-200k", + "step-1x-medium", } diff --git a/relay/adaptor/tencent/constants.go b/relay/adaptor/tencent/constants.go index be415a94c8..e8631e5f47 100644 --- a/relay/adaptor/tencent/constants.go +++ b/relay/adaptor/tencent/constants.go @@ -5,4 +5,5 @@ var ModelList = []string{ "hunyuan-standard", "hunyuan-standard-256K", "hunyuan-pro", + "hunyuan-vision", } diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go index d72ecf0e25..cbb2b67533 100644 --- a/relay/adaptor/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -39,8 +39,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { Model: &request.Model, Stream: &request.Stream, Messages: messages, - TopP: &request.TopP, - Temperature: &request.Temperature, + TopP: request.TopP, + Temperature: request.Temperature, } } diff --git a/relay/adaptor/vertexai/claude/adapter.go b/relay/adaptor/vertexai/claude/adapter.go index 57f3d43f24..cb911cfea0 100644 --- a/relay/adaptor/vertexai/claude/adapter.go +++ b/relay/adaptor/vertexai/claude/adapter.go @@ -13,7 +13,12 @@ import ( ) var ModelList = []string{ - "claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229", "claude-3-5-sonnet-v2@20241022", "claude-3-5-haiku@20241022", + "claude-3-haiku@20240307", + "claude-3-sonnet@20240229", + "claude-3-opus@20240229", + "claude-3-5-sonnet@20240620", + "claude-3-5-sonnet-v2@20241022", + "claude-3-5-haiku@20241022", } const anthropicVersion = "vertex-2023-10-16" diff --git a/relay/adaptor/vertexai/claude/model.go b/relay/adaptor/vertexai/claude/model.go index e1bd5dd48d..c08ba460d9 100644 --- a/relay/adaptor/vertexai/claude/model.go +++ b/relay/adaptor/vertexai/claude/model.go @@ -11,8 +11,8 @@ type Request struct { MaxTokens int `json:"max_tokens,omitempty"` StopSequences []string 
`json:"stop_sequences,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go index 43e6cbcde3..b537787553 100644 --- a/relay/adaptor/vertexai/gemini/adapter.go +++ b/relay/adaptor/vertexai/gemini/adapter.go @@ -15,7 +15,10 @@ import ( ) var ModelList = []string{ - "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", + "gemini-pro", "gemini-pro-vision", + "gemini-1.5-pro-001", "gemini-1.5-flash-001", + "gemini-1.5-pro-002", "gemini-1.5-flash-002", + "gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", } type Adaptor struct { diff --git a/relay/adaptor/xai/constants.go b/relay/adaptor/xai/constants.go new file mode 100644 index 0000000000..9082b999a3 --- /dev/null +++ b/relay/adaptor/xai/constants.go @@ -0,0 +1,5 @@ +package xai + +var ModelList = []string{ + "grok-beta", +} diff --git a/relay/adaptor/xunfei/constants.go b/relay/adaptor/xunfei/constants.go index 12a5621099..5b82ac292f 100644 --- a/relay/adaptor/xunfei/constants.go +++ b/relay/adaptor/xunfei/constants.go @@ -5,6 +5,8 @@ var ModelList = []string{ "SparkDesk-v1.1", "SparkDesk-v2.1", "SparkDesk-v3.1", + "SparkDesk-v3.1-128K", "SparkDesk-v3.5", + "SparkDesk-v3.5-32K", "SparkDesk-v4.0", } diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index 33f4f75135..0c05706913 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -275,9 +275,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, } func parseAPIVersionByModelName(modelName string) string { - parts := strings.Split(modelName, "-") - if 
len(parts) == 2 {
-		return parts[1]
+	index := strings.IndexAny(modelName, "-")
+	if index != -1 {
+		return modelName[index+1:]
 	}
 	return ""
 }
@@ -286,13 +286,17 @@ func parseAPIVersionByModelName(modelName string) string {
 func apiVersion2domain(apiVersion string) string {
 	switch apiVersion {
 	case "v1.1":
-		return "general"
+		return "lite"
 	case "v2.1":
 		return "generalv2"
 	case "v3.1":
 		return "generalv3"
+	case "v3.1-128K":
+		return "pro-128k"
 	case "v3.5":
 		return "generalv3.5"
+	case "v3.5-32K":
+		return "max-32k"
 	case "v4.0":
 		return "4.0Ultra"
 	}
@@ -300,7 +304,15 @@ func apiVersion2domain(apiVersion string) string {
 }
 
 func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
+	var authUrl string
 	domain := apiVersion2domain(apiVersion)
-	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
+	switch apiVersion {
+	case "v3.1-128K":
+		authUrl = buildXunfeiAuthUrl("wss://spark-api.xf-yun.com/chat/pro-128k", apiKey, apiSecret)
+	case "v3.5-32K":
+		authUrl = buildXunfeiAuthUrl("wss://spark-api.xf-yun.com/chat/max-32k", apiKey, apiSecret)
+	default:
+		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
+	}
 	return domain, authUrl
 }
diff --git a/relay/adaptor/xunfei/model.go b/relay/adaptor/xunfei/model.go
index 1f37c04655..c9fb1bb8f2 100644
--- a/relay/adaptor/xunfei/model.go
+++ b/relay/adaptor/xunfei/model.go
@@ -19,11 +19,11 @@ type ChatRequest struct {
 	} `json:"header"`
 	Parameter struct {
 		Chat struct {
-			Domain      string  `json:"domain,omitempty"`
-			Temperature float64 `json:"temperature,omitempty"`
-			TopK        int     `json:"top_k,omitempty"`
-			MaxTokens   int     `json:"max_tokens,omitempty"`
-			Auditing    bool    `json:"auditing,omitempty"`
+			Domain      string   `json:"domain,omitempty"`
+			Temperature *float64 `json:"temperature,omitempty"`
+			TopK        int      `json:"top_k,omitempty"`
+			MaxTokens   int
`json:"max_tokens,omitempty"` + Auditing bool `json:"auditing,omitempty"` } `json:"chat"` } `json:"parameter"` Payload struct { diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go index 78b01fb3f7..660bd37960 100644 --- a/relay/adaptor/zhipu/adaptor.go +++ b/relay/adaptor/zhipu/adaptor.go @@ -4,13 +4,13 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" "io" - "math" "net/http" "strings" ) @@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request) return baiduEmbeddingRequest, err default: - // TopP (0.0, 1.0) - request.TopP = math.Min(0.99, request.TopP) - request.TopP = math.Max(0.01, request.TopP) + // TopP [0.0, 1.0] + request.TopP = helper.Float64PtrMax(request.TopP, 1) + request.TopP = helper.Float64PtrMin(request.TopP, 0) - // Temperature (0.0, 1.0) - request.Temperature = math.Min(0.99, request.Temperature) - request.Temperature = math.Max(0.01, request.Temperature) + // Temperature [0.0, 1.0] + request.Temperature = helper.Float64PtrMax(request.Temperature, 1) + request.Temperature = helper.Float64PtrMin(request.Temperature, 0) a.SetVersionByModeName(request.Model) if a.APIVersion == "v4" { return request, nil diff --git a/relay/adaptor/zhipu/model.go b/relay/adaptor/zhipu/model.go index f91de1dced..06e22dc153 100644 --- a/relay/adaptor/zhipu/model.go +++ b/relay/adaptor/zhipu/model.go @@ -12,8 +12,8 @@ type Message struct { type Request struct { Prompt []Message `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP 
*float64 `json:"top_p,omitempty"` RequestId string `json:"request_id,omitempty"` Incremental bool `json:"incremental,omitempty"` } diff --git a/relay/apitype/define.go b/relay/apitype/define.go index cf7b6a0d2b..0c6a5ff11a 100644 --- a/relay/apitype/define.go +++ b/relay/apitype/define.go @@ -19,6 +19,7 @@ const ( DeepL VertexAI Proxy + Replicate Dummy // this one is only for count, do not add any channel after this ) diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go index ced0c6678c..c8c42a15c0 100644 --- a/relay/billing/ratio/image.go +++ b/relay/billing/ratio/image.go @@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{ "720x1280": 1, "1280x720": 1, }, + "step-1x-medium": { + "256x256": 1, + "512x512": 1, + "768x768": 1, + "1024x1024": 1, + "1280x800": 1, + "800x1280": 1, + }, } var ImageGenerationAmounts = map[string][2]int{ @@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{ "ali-stable-diffusion-v1.5": {1, 4}, // Ali "wanx-v1": {1, 4}, // Ali "cogview-3": {1, 1}, + "step-1x-medium": {1, 1}, } var ImagePromptLengthLimitations = map[string]int{ @@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{ "ali-stable-diffusion-v1.5": 4000, "wanx-v1": 4000, "cogview-3": 833, + "step-1x-medium": 4000, } var ImageOriginModelName = map[string]string{ diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 2a653d5832..755680cf0a 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -23,66 +23,77 @@ const ( // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ // https://openai.com/pricing - "gpt-4": 15, - "gpt-4-0314": 15, - "gpt-4-0613": 15, - "gpt-4-32k": 30, - "gpt-4-32k-0314": 30, - "gpt-4-32k-0613": 30, - "gpt-4-1106-preview": 5, // $0.01 / 1K tokens - "gpt-4-0125-preview": 5, // $0.01 / 1K tokens - "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens - "gpt-4-turbo": 5, // $0.01 / 1K tokens - "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens - 
-	"gpt-4o":                  2.5,   // $0.005 / 1K tokens
-	"gpt-4o-2024-05-13":       2.5,   // $0.005 / 1K tokens
-	"gpt-4o-mini":             0.075, // $0.00015 / 1K tokens
-	"gpt-4o-mini-2024-07-18":  0.075, // $0.00015 / 1K tokens
-	"gpt-4-vision-preview":    5,     // $0.01 / 1K tokens
-	"gpt-3.5-turbo":           0.25,  // $0.0005 / 1K tokens
-	"gpt-3.5-turbo-0301":      0.75,
-	"gpt-3.5-turbo-0613":      0.75,
-	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens
-	"gpt-3.5-turbo-16k-0613":  1.5,
-	"gpt-3.5-turbo-instruct":  0.75, // $0.0015 / 1K tokens
-	"gpt-3.5-turbo-1106":      0.5,  // $0.001 / 1K tokens
-	"gpt-3.5-turbo-0125":      0.25, // $0.0005 / 1K tokens
-	"davinci-002":             1,    // $0.002 / 1K tokens
-	"babbage-002":             0.2,  // $0.0004 / 1K tokens
-	"text-ada-001":            0.2,
-	"text-babbage-001":        0.25,
-	"text-curie-001":          1,
-	"text-davinci-002":        10,
-	"text-davinci-003":        10,
-	"text-davinci-edit-001":   10,
-	"code-davinci-edit-001":   10,
+	"gpt-4":                   15,
+	"gpt-4-0314":              15,
+	"gpt-4-0613":              15,
+	"gpt-4-32k":               30,
+	"gpt-4-32k-0314":          30,
+	"gpt-4-32k-0613":          30,
+	"gpt-4-1106-preview":      5,     // $0.01 / 1K tokens
+	"gpt-4-0125-preview":      5,     // $0.01 / 1K tokens
+	"gpt-4-turbo-preview":     5,     // $0.01 / 1K tokens
+	"gpt-4-turbo":             5,     // $0.01 / 1K tokens
+	"gpt-4-turbo-2024-04-09":  5,     // $0.01 / 1K tokens
+	"gpt-4o":                  2.5,   // $0.005 / 1K tokens
+	"chatgpt-4o-latest":       2.5,   // $0.005 / 1K tokens
+	"gpt-4o-2024-05-13":       2.5,   // $0.005 / 1K tokens
+	"gpt-4o-2024-08-06":       1.25,  // $0.0025 / 1K tokens
+	"gpt-4o-2024-11-20":       1.25,  // $0.0025 / 1K tokens
+	"gpt-4o-mini":             0.075, // $0.00015 / 1K tokens
+	"gpt-4o-mini-2024-07-18":  0.075, // $0.00015 / 1K tokens
+	"gpt-4-vision-preview":    5,     // $0.01 / 1K tokens
+	"gpt-3.5-turbo":           0.25,  // $0.0005 / 1K tokens
+	"gpt-3.5-turbo-0301":      0.75,
+	"gpt-3.5-turbo-0613":      0.75,
+	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens
+	"gpt-3.5-turbo-16k-0613":  1.5,
+	"gpt-3.5-turbo-instruct":  0.75, // $0.0015 / 1K tokens
+	"gpt-3.5-turbo-1106":      0.5,  // $0.001 / 1K tokens
+	"gpt-3.5-turbo-0125":      0.25, // $0.0005 / 1K tokens
+	"o1":                      7.5, // $15.00 / 1M input tokens
+	"o1-2024-12-17":           7.5,
+	"o1-preview":              7.5, // $15.00 / 1M input tokens
+	"o1-preview-2024-09-12":   7.5,
+	"o1-mini":                 1.5, // $3.00 / 1M input tokens
+	"o1-mini-2024-09-12":      1.5,
+	"davinci-002":             1,   // $0.002 / 1K tokens
+	"babbage-002":             0.2, // $0.0004 / 1K tokens
+	"text-ada-001":            0.2,
+	"text-babbage-001":        0.25,
+	"text-curie-001":          1,
+	"text-davinci-002":        10,
+	"text-davinci-003":        10,
+	"text-davinci-edit-001":   10,
+	"code-davinci-edit-001":   10,
 	//"whisper-1": 15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
 	"whisper-1":                  50,     // $0.1 / 1K sec
 	"whisper-large-v3":           15.417, // $0.111 / 1h
 	"distil-whisper-large-v3-en": 2.778,  //$0.02 /h
-	"tts-1":                   7.5, // $0.015 / 1K characters
-	"tts-1-1106":              7.5,
-	"tts-1-hd":                15, // $0.030 / 1K characters
-	"tts-1-hd-1106":           15,
-	"davinci":                 10,
-	"curie":                   10,
-	"babbage":                 10,
-	"ada":                     10,
-	"text-embedding-ada-002":  0.05,
-	"text-embedding-3-small":  0.01,
-	"text-embedding-3-large":  0.065,
-	"text-search-ada-doc-001": 10,
-	"text-moderation-stable":  0.1,
-	"text-moderation-latest":  0.1,
-	"dall-e-2":                0.02 * USD, // $0.016 - $0.020 / image
-	"dall-e-3":                0.04 * USD, // $0.040 - $0.120 / image
+	"tts-1":                   7.5, // $0.015 / 1K characters
+	"tts-1-1106":              7.5,
+	"tts-1-hd":                15, // $0.030 / 1K characters
+	"tts-1-hd-1106":           15,
+	"davinci":                 10,
+	"curie":                   10,
+	"babbage":                 10,
+	"ada":                     10,
+	"text-embedding-ada-002":  0.05,
+	"text-embedding-3-small":  0.01,
+	"text-embedding-3-large":  0.065,
+	"text-search-ada-doc-001": 10,
+	"text-moderation-stable":  0.1,
+	"text-moderation-latest":  0.1,
+	"dall-e-2":                0.02 * USD, // $0.016 - $0.020 / image
+	"dall-e-3":                0.04 * USD, // $0.040 - $0.120 / image
 	// https://www.anthropic.com/api#pricing
 	"claude-instant-1.2":         0.8 / 1000 * USD,
 	"claude-2.0":                 8.0 / 1000 * USD,
 	"claude-2.1":                 8.0 / 1000 * USD,
 	"claude-3-haiku-20240307":    0.25 / 1000 * USD,
+	"claude-3-5-haiku-20241022":  1.0 / 1000 * USD,
 	"claude-3-sonnet-20240229":   3.0 / 1000 * USD,
 	"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
+	"claude-3-5-sonnet-20241022": 3.0 / 1000 * USD,
 	"claude-3-opus-20240229":     15.0 / 1000 * USD,
 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
 	"ERNIE-4.0-8K": 0.120 * RMB,
@@ -102,11 +113,15 @@ var ModelRatio = map[string]float64{
 	"bge-large-en": 0.002 * RMB,
 	"tao-8k":       0.002 * RMB,
 	// https://ai.google.dev/pricing
-	"gemini-pro":       1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
-	"gemini-1.0-pro":   1,
-	"gemini-1.5-flash": 1,
-	"gemini-1.5-pro":   1,
-	"aqa":              1,
+	"gemini-pro":                    1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+	"gemini-1.0-pro":                1,
+	"gemini-1.5-pro":                1,
+	"gemini-1.5-pro-001":            1,
+	"gemini-1.5-flash":              1,
+	"gemini-1.5-flash-001":          1,
+	"gemini-2.0-flash-exp":          1,
+	"gemini-2.0-flash-thinking-exp": 1,
+	"aqa":                           1,
 	// https://open.bigmodel.cn/pricing
 	"glm-4":  0.1 * RMB,
 	"glm-4v": 0.1 * RMB,
@@ -118,27 +133,94 @@ var ModelRatio = map[string]float64{
 	"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
 	"cogview-3":    0.25 * RMB,
 	// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
-	"qwen-turbo":                0.5715, // ¥0.008 / 1k tokens
-	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens
-	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens
-	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens
-	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
-	"ali-stable-diffusion-xl":   8,
-	"ali-stable-diffusion-v1.5": 8,
-	"wanx-v1":                   8,
-	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
-	"SparkDesk-v1.1":            1.2858, // ¥0.018 / 1k tokens
-	"SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens
-	"SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens
-	"SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens
-	"SparkDesk-v4.0":            1.2858, // ¥0.018 / 1k tokens
-	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
-	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
-	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
-	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
-	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
-	"ChatStd":                   0.01 * RMB,
-	"ChatPro":                   0.1 * RMB,
+	"qwen-turbo":                  1.4286, // ¥0.02 / 1k tokens
+	"qwen-turbo-latest":           1.4286,
+	"qwen-plus":                   1.4286,
+	"qwen-plus-latest":            1.4286,
+	"qwen-max":                    1.4286,
+	"qwen-max-latest":             1.4286,
+	"qwen-max-longcontext":        1.4286,
+	"qwen-vl-max":                 1.4286,
+	"qwen-vl-max-latest":          1.4286,
+	"qwen-vl-plus":                1.4286,
+	"qwen-vl-plus-latest":         1.4286,
+	"qwen-vl-ocr":                 1.4286,
+	"qwen-vl-ocr-latest":          1.4286,
+	"qwen-audio-turbo":            1.4286,
+	"qwen-math-plus":              1.4286,
+	"qwen-math-plus-latest":       1.4286,
+	"qwen-math-turbo":             1.4286,
+	"qwen-math-turbo-latest":      1.4286,
+	"qwen-coder-plus":             1.4286,
+	"qwen-coder-plus-latest":      1.4286,
+	"qwen-coder-turbo":            1.4286,
+	"qwen-coder-turbo-latest":     1.4286,
+	"qwq-32b-preview":             1.4286,
+	"qwen2.5-72b-instruct":        1.4286,
+	"qwen2.5-32b-instruct":        1.4286,
+	"qwen2.5-14b-instruct":        1.4286,
+	"qwen2.5-7b-instruct":         1.4286,
+	"qwen2.5-3b-instruct":         1.4286,
+	"qwen2.5-1.5b-instruct":       1.4286,
+	"qwen2.5-0.5b-instruct":       1.4286,
+	"qwen2-72b-instruct":          1.4286,
+	"qwen2-57b-a14b-instruct":     1.4286,
+	"qwen2-7b-instruct":           1.4286,
+	"qwen2-1.5b-instruct":         1.4286,
+	"qwen2-0.5b-instruct":         1.4286,
+	"qwen1.5-110b-chat":           1.4286,
+	"qwen1.5-72b-chat":            1.4286,
+	"qwen1.5-32b-chat":            1.4286,
+	"qwen1.5-14b-chat":            1.4286,
+	"qwen1.5-7b-chat":             1.4286,
+	"qwen1.5-1.8b-chat":           1.4286,
+	"qwen1.5-0.5b-chat":           1.4286,
+	"qwen-72b-chat":               1.4286,
+	"qwen-14b-chat":               1.4286,
+	"qwen-7b-chat":                1.4286,
+	"qwen-1.8b-chat":              1.4286,
+	"qwen-1.8b-longcontext-chat":  1.4286,
+	"qwen2-vl-7b-instruct":        1.4286,
+	"qwen2-vl-2b-instruct":        1.4286,
+	"qwen-vl-v1":                  1.4286,
+	"qwen-vl-chat-v1":             1.4286,
+	"qwen2-audio-instruct":        1.4286,
+	"qwen-audio-chat":             1.4286,
+	"qwen2.5-math-72b-instruct":   1.4286,
+	"qwen2.5-math-7b-instruct":    1.4286,
+	"qwen2.5-math-1.5b-instruct":  1.4286,
+	"qwen2-math-72b-instruct":     1.4286,
+	"qwen2-math-7b-instruct":      1.4286,
+	"qwen2-math-1.5b-instruct":    1.4286,
+	"qwen2.5-coder-32b-instruct":  1.4286,
+	"qwen2.5-coder-14b-instruct":  1.4286,
+	"qwen2.5-coder-7b-instruct":   1.4286,
+	"qwen2.5-coder-3b-instruct":   1.4286,
+	"qwen2.5-coder-1.5b-instruct": 1.4286,
+	"qwen2.5-coder-0.5b-instruct": 1.4286,
+	"text-embedding-v1":           0.05, // ¥0.0007 / 1k tokens
+	"text-embedding-v3":           0.05,
+	"text-embedding-v2":           0.05,
+	"text-embedding-async-v2":     0.05,
+	"text-embedding-async-v1":     0.05,
+	"ali-stable-diffusion-xl":     8.00,
+	"ali-stable-diffusion-v1.5":   8.00,
+	"wanx-v1":                     8.00,
+	"SparkDesk":                   1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v1.1":              1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v2.1":              1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v3.1":              1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v3.1-128K":         1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v3.5":              1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v3.5-32K":          1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v4.0":              1.2858, // ¥0.018 / 1k tokens
+	"360GPT_S2_V9":                0.8572, // ¥0.012 / 1k tokens
+	"embedding-bert-512-v1":       0.0715, // ¥0.001 / 1k tokens
+	"embedding_s1_v1":             0.0715, // ¥0.001 / 1k tokens
+	"semantic_similarity_s1_v1":   0.0715, // ¥0.001 / 1k tokens
+	"hunyuan":                     7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
+	"ChatStd":                     0.01 * RMB,
+	"ChatPro":                     0.1 * RMB,
 	// https://platform.moonshot.cn/pricing
 	"moonshot-v1-8k":  0.012 * RMB,
 	"moonshot-v1-32k": 0.024 * RMB,
@@ -162,23 +244,34 @@ var ModelRatio = map[string]float64{
 	"mistral-embed": 0.1 / 1000 * USD,
 	// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed
 	"gemma-7b-it": 0.07 / 1000000 * USD,
-	"mixtral-8x7b-32768": 0.24 / 1000000 * USD,
-	"llama3-8b-8192":     0.05 / 1000000 * USD,
-	"llama3-70b-8192":    0.59 / 1000000 * USD,
 	"gemma2-9b-it": 0.20 / 1000000 * USD,
-	"llama-3.1-405b-reasoning": 0.89 / 1000000 * USD,
 	"llama-3.1-70b-versatile": 0.59 / 1000000 * USD,
 	"llama-3.1-8b-instant":    0.05 / 1000000 * USD,
+	"llama-3.2-11b-text-preview":   0.05 / 1000000 * USD,
+	"llama-3.2-11b-vision-preview": 0.05 / 1000000 * USD,
+	"llama-3.2-1b-preview":         0.05 / 1000000 * USD,
+	"llama-3.2-3b-preview":         0.05 / 1000000 * USD,
+	"llama-3.2-90b-text-preview":   0.59 / 1000000 * USD,
+	"llama-guard-3-8b":             0.05 / 1000000 * USD,
+	"llama3-70b-8192":              0.59 / 1000000 * USD,
+	"llama3-8b-8192":               0.05 / 1000000 * USD,
 	"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD,
 	"llama3-groq-8b-8192-tool-use-preview":  0.19 / 1000000 * USD,
+	"mixtral-8x7b-32768": 0.24 / 1000000 * USD,
 	// https://platform.lingyiwanwu.com/docs#-计费单元
 	"yi-34b-chat-0205": 2.5 / 1000 * RMB,
 	"yi-34b-chat-200k": 12.0 / 1000 * RMB,
 	"yi-vl-plus":       6.0 / 1000 * RMB,
-	// stepfun todo
-	"step-1v-32k": 0.024 * RMB,
-	"step-1-32k":  0.024 * RMB,
-	"step-1-200k": 0.15 * RMB,
+	// https://platform.stepfun.com/docs/pricing/details
+	"step-1-8k":    0.005 / 1000 * RMB,
+	"step-1-32k":   0.015 / 1000 * RMB,
+	"step-1-128k":  0.040 / 1000 * RMB,
+	"step-1-256k":  0.095 / 1000 * RMB,
+	"step-1-flash": 0.001 / 1000 * RMB,
+	"step-2-16k":   0.038 / 1000 * RMB,
+	"step-1v-8k":   0.005 / 1000 * RMB,
+	"step-1v-32k":  0.015 / 1000 * RMB,
 	// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
 	"llama3-8b-8192(33)":  0.0003 / 0.002,  // $0.0003 / 1K tokens
 	"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
@@ -196,6 +289,52 @@ var ModelRatio = map[string]float64{
 	"deepl-zh": 25.0 / 1000 * USD,
 	"deepl-en": 25.0 / 1000 * USD,
 	"deepl-ja": 25.0 / 1000 * USD,
+	// https://console.x.ai/
+	"grok-beta": 5.0 / 1000 * USD,
+	// replicate charges based on the number of generated images
+	// https://replicate.com/pricing
+	"black-forest-labs/flux-1.1-pro":                0.04 * USD,
+	"black-forest-labs/flux-1.1-pro-ultra":          0.06 * USD,
+	"black-forest-labs/flux-canny-dev":              0.025 * USD,
+	"black-forest-labs/flux-canny-pro":              0.05 * USD,
+	"black-forest-labs/flux-depth-dev":              0.025 * USD,
+	"black-forest-labs/flux-depth-pro":              0.05 * USD,
+	"black-forest-labs/flux-dev":                    0.025 * USD,
+	"black-forest-labs/flux-dev-lora":               0.032 * USD,
+	"black-forest-labs/flux-fill-dev":               0.04 * USD,
+	"black-forest-labs/flux-fill-pro":               0.05 * USD,
+	"black-forest-labs/flux-pro":                    0.055 * USD,
+	"black-forest-labs/flux-redux-dev":              0.025 * USD,
+	"black-forest-labs/flux-redux-schnell":          0.003 * USD,
+	"black-forest-labs/flux-schnell":                0.003 * USD,
+	"black-forest-labs/flux-schnell-lora":           0.02 * USD,
+	"ideogram-ai/ideogram-v2":                       0.08 * USD,
+	"ideogram-ai/ideogram-v2-turbo":                 0.05 * USD,
+	"recraft-ai/recraft-v3":                         0.04 * USD,
+	"recraft-ai/recraft-v3-svg":                     0.08 * USD,
+	"stability-ai/stable-diffusion-3":               0.035 * USD,
+	"stability-ai/stable-diffusion-3.5-large":       0.065 * USD,
+	"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
+	"stability-ai/stable-diffusion-3.5-medium":      0.035 * USD,
+	// replicate chat models
+	"ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD,
+	"ibm-granite/granite-3.0-2b-instruct":       0.030 * USD,
+	"ibm-granite/granite-3.0-8b-instruct":       0.050 * USD,
+	"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
+	"meta/llama-2-13b":                          0.100 * USD,
+	"meta/llama-2-13b-chat":                     0.100 * USD,
+	"meta/llama-2-70b":                          0.650 * USD,
+	"meta/llama-2-70b-chat":                     0.650 * USD,
+	"meta/llama-2-7b":                           0.050 * USD,
+	"meta/llama-2-7b-chat":                      0.050 * USD,
+	"meta/meta-llama-3.1-405b-instruct":         9.500 * USD,
+	"meta/meta-llama-3-70b":                     0.650 * USD,
+	"meta/meta-llama-3-70b-instruct":            0.650 * USD,
+	"meta/meta-llama-3-8b":                      0.050 * USD,
+	"meta/meta-llama-3-8b-instruct":             0.050 * USD,
+	"mistralai/mistral-7b-instruct-v0.2":        0.050 * USD,
+	"mistralai/mistral-7b-v0.1":                 0.050 * USD,
+	"mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD,
 }
 
 var CompletionRatio = map[string]float64{
@@ -204,8 +343,10 @@ var CompletionRatio = map[string]float64{
 	"llama3-70b-8192(33)": 0.0035 / 0.00265,
 }
 
-var DefaultModelRatio map[string]float64
-var DefaultCompletionRatio map[string]float64
+var (
+	DefaultModelRatio      map[string]float64
+	DefaultCompletionRatio map[string]float64
+)
 
 type ModelRatioConfig struct {
 	ModelRatio float64
@@ -358,16 +499,25 @@ func GetCompletionRatio(name string, channelType int) float64 {
 		return 4.0 / 3.0
 	}
 	if strings.HasPrefix(name, "gpt-4") {
-		if strings.HasPrefix(name, "gpt-4o-mini") {
+		if strings.HasPrefix(name, "gpt-4o") {
+			if name == "gpt-4o-2024-05-13" {
+				return 3
+			}
 			return 4
 		}
 		if strings.HasPrefix(name, "gpt-4-turbo") ||
-			strings.HasPrefix(name, "gpt-4o") ||
 			strings.HasSuffix(name, "preview") {
 			return 3
 		}
 		return 2
 	}
+	// including o1, o1-preview, o1-mini
+	if strings.HasPrefix(name, "o1") {
+		return 4
+	}
+	if name == "chatgpt-4o-latest" {
+		return 3
+	}
 	if strings.HasPrefix(name, "claude-3") {
 		return 5
 	}
@@ -383,6 +533,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
 	if strings.HasPrefix(name, "deepseek-") {
 		return 2
 	}
+
 	switch name {
 	case "llama2-70b-4096":
 		return 0.8 / 0.64
@@ -396,6 +547,37 @@ func GetCompletionRatio(name string, channelType int) float64 {
 		return 3
 	case "command-r-plus":
 		return 5
+	case "grok-beta":
+		return 3
+	// Replicate Models
+	// https://replicate.com/pricing
+	case "ibm-granite/granite-20b-code-instruct-8k":
+		return 5
+	case "ibm-granite/granite-3.0-2b-instruct":
+		return 8.333333333333334
+	case "ibm-granite/granite-3.0-8b-instruct",
+		"ibm-granite/granite-8b-code-instruct-128k":
+		return 5
+	case "meta/llama-2-13b",
+		"meta/llama-2-13b-chat",
+		"meta/llama-2-7b",
+		"meta/llama-2-7b-chat",
+		"meta/meta-llama-3-8b",
+		"meta/meta-llama-3-8b-instruct":
+		return 5
+	case "meta/llama-2-70b",
+		"meta/llama-2-70b-chat",
+		"meta/meta-llama-3-70b",
+		"meta/meta-llama-3-70b-instruct":
+		return 2.750 / 0.650 // ≈4.230769
+	case "meta/meta-llama-3.1-405b-instruct":
+		return 1
+	case "mistralai/mistral-7b-instruct-v0.2",
+		"mistralai/mistral-7b-v0.1":
+		return 5
+	case "mistralai/mixtral-8x7b-instruct-v0.1":
+		return 1.000 / 0.300 // ≈3.333333
 	}
+
 	return 1
 }
diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go
index a261cff85d..f54d0e30de
100644
--- a/relay/channeltype/define.go
+++ b/relay/channeltype/define.go
@@ -46,5 +46,7 @@ const (
 	VertextAI
 	Proxy
 	SiliconFlow
+	XAI
+	Replicate
 	Dummy
 )
diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go
index fae3357f8c..8839b30adb 100644
--- a/relay/channeltype/helper.go
+++ b/relay/channeltype/helper.go
@@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
 	case apitype.DeepL
 	case VertextAI:
 		apiType = apitype.VertexAI
+	case Replicate:
+		apiType = apitype.Replicate
 	case Proxy:
 		apiType = apitype.Proxy
 	}
diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go
index 2b06a7e17e..6de86f19fb 100644
--- a/relay/channeltype/url.go
+++ b/relay/channeltype/url.go
@@ -46,6 +46,8 @@ var ChannelBaseURLs = []string{
 	"",                           // 42
 	"",                           // 43
 	"https://api.siliconflow.cn", // 44
+	"https://api.x.ai",                     // 45
+	"https://api.replicate.com/v1/models/", // 46
 }
 
 func init() {
diff --git a/relay/constant/role/define.go b/relay/constant/role/define.go
index 972488c5c9..5097c97e21 100644
--- a/relay/constant/role/define.go
+++ b/relay/constant/role/define.go
@@ -1,5 +1,6 @@
 package role
 
 const (
+	System    = "system"
 	Assistant = "assistant"
 )
diff --git a/relay/controller/audio.go b/relay/controller/audio.go
index 71a838bf34..bd593b9309 100644
--- a/relay/controller/audio.go
+++ b/relay/controller/audio.go
@@ -112,16 +112,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 	}()
 
 	// map model name
-	modelMapping := c.GetString(ctxkey.ModelMapping)
-	if modelMapping != "" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-		}
-		if modelMap[audioModel] != "" {
-			audioModel = modelMap[audioModel]
-		}
+	modelMapping := c.GetStringMapString(ctxkey.ModelMapping)
+	if modelMapping != nil && modelMapping[audioModel] != "" {
+		audioModel = modelMapping[audioModel]
 	}
 
 	baseURL := channeltype.ChannelBaseURLs[channelType]
diff --git a/relay/controller/helper.go b/relay/controller/helper.go
index 1b6bfb25d1..2d79d70dc7 100644
--- a/relay/controller/helper.go
+++ b/relay/controller/helper.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"github.com/songquanpeng/one-api/relay/constant/role"
 	"math"
 	"net/http"
 	"strings"
@@ -86,7 +87,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
 	return preConsumedQuota, nil
 }
 
-func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
+func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) {
 	if usage == nil {
 		logger.Error(ctx, "usage is nil, which is unexpected")
 		return
@@ -120,7 +121,11 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
 	if err != nil {
 		logger.Error(ctx, "error update user quota cache: "+err.Error())
 	}
-	logContent := fmt.Sprintf("模型倍率 %.3f,分组倍率 %.3f,补全倍率 %.3f", modelRatio, groupRatio, completionRatio)
+	var extraLog string
+	if systemPromptReset {
+		extraLog = " (注意系统提示词已被重置)"
+	}
+	logContent := fmt.Sprintf("模型倍率 %.3f,分组倍率 %.3f,补全倍率 %.3f%s", modelRatio, groupRatio, completionRatio, extraLog)
 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, cachedTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
 	model.UpdateChannelUsedQuota(meta.ChannelId, quota)
@@ -144,15 +149,41 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
 		}
 		return true
 	}
-	if resp.StatusCode != http.StatusOK {
+	if resp.StatusCode != http.StatusOK &&
+		// replicate returns 201 to create a task
+		resp.StatusCode != http.StatusCreated {
 		return true
 	}
 	if meta.ChannelType == channeltype.DeepL {
 		// skip stream check for deepl
 		return false
 	}
-	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
+
+	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
+		// Even if stream mode is enabled, replicate will first return a task info in JSON format,
+		// requiring the client to request the stream endpoint in the task info
+		meta.ChannelType != channeltype.Replicate {
 		return true
 	}
 	return false
 }
+
+func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) {
+	if prompt == "" {
+		return false
+	}
+	if len(request.Messages) == 0 {
+		return false
+	}
+	if request.Messages[0].Role == role.System {
+		request.Messages[0].Content = prompt
+		logger.Infof(ctx, "rewrite system prompt")
+		return true
+	}
+	request.Messages = append([]relaymodel.Message{{
+		Role:    role.System,
+		Content: prompt,
+	}}, request.Messages...)
+	logger.Infof(ctx, "add system prompt")
+	return true
+}
diff --git a/relay/controller/image.go b/relay/controller/image.go
index 104b30dd12..8ca0a4c63b 100644
--- a/relay/controller/image.go
+++ b/relay/controller/image.go
@@ -24,7 +24,7 @@ import (
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
 )
 
-func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
+func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
 	imageRequest := &relaymodel.ImageRequest{}
 	err := common.UnmarshalBodyReusable(c, imageRequest)
 	if err != nil {
@@ -67,7 +67,7 @@ func getImageSizeRatio(model string, size string) float64 {
 	return 1
 }
 
-func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta, relayMode int) *relaymodel.ErrorWithStatusCode {
+func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta, relayMode int) *relaymodel.ErrorWithStatusCode {
 	// check prompt length
 	if imageRequest.Prompt == "" && (relayMode == relaymode.ImagesEdits || relayMode == relaymode.ImagesGenerations) {
 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
@@ -153,12 +153,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 	}
 	adaptor.Init(meta)
 
+	// these adaptors need to convert the request
 	switch meta.ChannelType {
-	case channeltype.Ali:
-		fallthrough
-	case channeltype.Baidu:
-		fallthrough
-	case channeltype.Zhipu:
+	case channeltype.Zhipu,
+		channeltype.Ali,
+		channeltype.Replicate,
+		channeltype.Baidu:
 		finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
 		if err != nil {
 			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
@@ -175,7 +175,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 	ratio := modelRatio * groupRatio
 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
 
-	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
+	var quota int64
+	switch meta.ChannelType {
+	case channeltype.Replicate:
+		// replicate always returns 1 image
+		quota = int64(ratio * imageCostRatio * 1000)
+	default:
+		quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
+	}
 
 	if userQuota-quota < 0 {
 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@@ -189,7 +196,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 	}
 
 	defer func(ctx context.Context) {
-		if resp != nil && resp.StatusCode != http.StatusOK {
+		if resp != nil &&
+			resp.StatusCode != http.StatusCreated && // replicate returns 201
+			resp.StatusCode != http.StatusOK {
 			return
 		}
diff --git a/relay/controller/text.go b/relay/controller/text.go
index ed7133c19e..70590737fe 100644
--- a/relay/controller/text.go
+++ b/relay/controller/text.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"github.com/songquanpeng/one-api/common/config"
 	"io"
 	"net/http"
 	"strings"
@@ -36,6 +37,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
 	meta.OriginModelName = textRequest.Model
 	textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
 	meta.ActualModelName = textRequest.Model
+	// set system prompt if not empty
+	systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
 	// get model ratio & group ratio
 	modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
 	groupRatio := billingratio.GetGroupRatio(meta.Group)
@@ -81,12 +84,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
 		return respErr
 	}
 	// post-consume quota
-	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
+	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset)
 	return nil
 }
 
 func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
-	if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan &&
+	if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan &&
 		meta.ChannelType != channeltype.Azure && !strings.Contains(meta.BaseURL, "ai.azure.com") {
 		// no need to convert request for openai
 		return c.Request.Body, nil
diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go
index b1761e9a7c..bcbe10453a 100644
--- a/relay/meta/relay_meta.go
+++ b/relay/meta/relay_meta.go
@@ -30,6 +30,7 @@ type Meta struct {
 	ActualModelName string
 	RequestURLPath  string
 	PromptTokens    int // only for DoResponse
+	SystemPrompt    string
 }
 
 func GetByContext(c *gin.Context) *Meta {
@@ -46,6 +47,7 @@ func GetByContext(c *gin.Context) *Meta {
 		BaseURL:        c.GetString(ctxkey.BaseURL),
 		APIKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
 		RequestURLPath: c.Request.URL.String(),
+		SystemPrompt:   c.GetString(ctxkey.SystemPrompt),
 	}
 	cfg, ok := c.Get(ctxkey.Config)
 	if ok {
diff --git a/relay/model/constant.go b/relay/model/constant.go
index f6cf1924d1..c9d6d645c6 100644
--- a/relay/model/constant.go
+++ b/relay/model/constant.go
@@ -1,6 +1,7 @@
 package model
 
 const (
-	ContentTypeText     = "text"
-	ContentTypeImageURL = "image_url"
+	ContentTypeText       = "text"
+	ContentTypeImageURL   = "image_url"
+	ContentTypeInputAudio = "input_audio"
 )
diff --git a/relay/model/general.go b/relay/model/general.go
index c34c1c2d5d..288c07ffb5 100644
--- a/relay/model/general.go
+++ b/relay/model/general.go
@@ -1,35 +1,70 @@
 package model
 
 type ResponseFormat struct {
-	Type string `json:"type,omitempty"`
+	Type       string      `json:"type,omitempty"`
+	JsonSchema *JSONSchema `json:"json_schema,omitempty"`
+}
+
+type JSONSchema struct {
+	Description string                 `json:"description,omitempty"`
+	Name        string                 `json:"name"`
+	Schema      map[string]interface{} `json:"schema,omitempty"`
+	Strict      *bool                  `json:"strict,omitempty"`
+}
+
+type Audio struct {
+	Voice  string `json:"voice,omitempty"`
+	Format string `json:"format,omitempty"`
+}
+
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage,omitempty"`
 }
 
 type GeneralOpenAIRequest struct {
-	Messages         []Message       `json:"messages,omitempty"`
-	Model            string          `json:"model,omitempty"`
-	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
-	MaxTokens        int             `json:"max_tokens,omitempty"`
-	N                int             `json:"n,omitempty"`
-	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
-	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
-	Seed             float64         `json:"seed,omitempty"`
-	Stop             any             `json:"stop,omitempty"`
-	Stream           bool            `json:"stream,omitempty"`
-	Temperature      float64         `json:"temperature,omitempty"`
-	TopP             float64         `json:"top_p,omitempty"`
-	TopK             int             `json:"top_k,omitempty"`
-	Tools            []Tool          `json:"tools,omitempty"`
-	ToolChoice       any             `json:"tool_choice,omitempty"`
-	FunctionCall     any             `json:"function_call,omitempty"`
-	Functions        any             `json:"functions,omitempty"`
-	User             string          `json:"user,omitempty"`
-	Prompt           any             `json:"prompt,omitempty"`
-	Input            any             `json:"input,omitempty"`
-	EncodingFormat   string          `json:"encoding_format,omitempty"`
-	Dimensions       int             `json:"dimensions,omitempty"`
-	Instruction      string          `json:"instruction,omitempty"`
-	Size             string          `json:"size,omitempty"`
-	NumCtx           int             `json:"num_ctx,omitempty"`
+	// https://platform.openai.com/docs/api-reference/chat/create
+	Messages            []Message       `json:"messages,omitempty"`
+	Model               string          `json:"model,omitempty"`
+	Store               *bool           `json:"store,omitempty"`
+	Metadata            any             `json:"metadata,omitempty"`
+	FrequencyPenalty    *float64        `json:"frequency_penalty,omitempty"`
+	LogitBias           any             `json:"logit_bias,omitempty"`
+	Logprobs            *bool           `json:"logprobs,omitempty"`
+	TopLogprobs         *int            `json:"top_logprobs,omitempty"`
+	MaxTokens           int             `json:"max_tokens,omitempty"`
+	MaxCompletionTokens *int            `json:"max_completion_tokens,omitempty"`
+	N                   int             `json:"n,omitempty"`
+	Modalities          []string        `json:"modalities,omitempty"`
+	Prediction          any             `json:"prediction,omitempty"`
+	Audio               *Audio          `json:"audio,omitempty"`
+	PresencePenalty     *float64        `json:"presence_penalty,omitempty"`
+	ResponseFormat      *ResponseFormat `json:"response_format,omitempty"`
+	Seed                float64         `json:"seed,omitempty"`
+	ServiceTier         *string         `json:"service_tier,omitempty"`
+	Stop                any             `json:"stop,omitempty"`
+	Stream              bool            `json:"stream,omitempty"`
+	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`
+	Temperature         *float64        `json:"temperature,omitempty"`
+	TopP                *float64        `json:"top_p,omitempty"`
+	TopK                int             `json:"top_k,omitempty"`
+	Tools               []Tool          `json:"tools,omitempty"`
+	ToolChoice          any             `json:"tool_choice,omitempty"`
+	ParallelTooCalls    *bool           `json:"parallel_tool_calls,omitempty"`
+	User                string          `json:"user,omitempty"`
+	FunctionCall        any             `json:"function_call,omitempty"`
+	Functions           any             `json:"functions,omitempty"`
+	// https://platform.openai.com/docs/api-reference/embeddings/create
+	Input          any    `json:"input,omitempty"`
+	EncodingFormat string `json:"encoding_format,omitempty"`
+	Dimensions     int    `json:"dimensions,omitempty"`
+	// https://platform.openai.com/docs/api-reference/images/create
+	Prompt  any     `json:"prompt,omitempty"`
+	Quality *string `json:"quality,omitempty"`
+	Size    string  `json:"size,omitempty"`
+	Style   *string `json:"style,omitempty"`
+	// Others
+	Instruction string `json:"instruction,omitempty"`
+	NumCtx      int    `json:"num_ctx,omitempty"`
 }
 
 func (r GeneralOpenAIRequest) ParseInput() []string {
diff --git a/router/api.go b/router/api.go
index d4a53af2ab..b43f297710 100644
--- a/router/api.go
+++ b/router/api.go
@@ -30,6 +30,7 @@ func SetApiRouter(router *gin.Engine) {
 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
 		apiRouter.GET("/oauth/google", middleware.CriticalRateLimit(), auth.GoogleOAuth)
+		apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth)
 		apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
 		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)
diff --git a/router/relay.go b/router/relay.go
index d7908f7ab4..d29602c2b7 100644
--- a/router/relay.go
+++ b/router/relay.go
@@ -9,6 +9,7 @@ import (
 
 func SetRelayRouter(router *gin.Engine) {
 	router.Use(middleware.CORS())
+	router.Use(middleware.GzipDecodeMiddleware())
 	// https://platform.openai.com/docs/api-reference/introduction
 	modelsRouter := router.Group("/v1/models")
 	modelsRouter.Use(middleware.TryTokenAuth())
diff --git a/web/air/src/components/TokensTable.js b/web/air/src/components/TokensTable.js
index 0853ddfbee..48836c859a 100644
--- a/web/air/src/components/TokensTable.js
+++ b/web/air/src/components/TokensTable.js
@@ -11,12 +11,14 @@ import EditToken from '../pages/Token/EditToken';
 const COPY_OPTIONS = [
   { key: 'next', text: 'ChatGPT Next Web', value: 'next' },
   { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
-  { key: 'opencat', text: 'OpenCat', value: 'opencat' }
+  { key: 'opencat', text: 'OpenCat', value: 'opencat' },
+  { key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
 ];
 
 const OPEN_LINK_OPTIONS = [
   { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
-  { key: 'opencat', text: 'OpenCat', value: 'opencat' }
+  { key: 'opencat', text: 'OpenCat', value: 'opencat' },
+  { key: 'lobechat', text: 'LobeChat', value: 'lobechat' }
 ];
 
 function renderTimestamp(timestamp) {
@@ -60,7 +62,12 @@ const TokensTable = () => {
         onOpenLink('next-mj');
       }
     },
-    { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' }
+    { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' },
+    {
+      node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
+        onOpenLink('lobechat');
+      }
+    }
   ];
 
   const columns = [
@@ -177,6 +184,11 @@ const TokensTable = () => {
             node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => {
               onOpenLink('opencat', record.key);
             }
+          },
+          {
+            node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
+              onOpenLink('lobechat');
+            }
           }
         ]
       }
@@ -382,6 +394,9 @@ const TokensTable = () => {
       case 'next-mj':
        url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
        break;
+      case 'lobechat':
+        url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`;
+        break;
       default:
         if (!chatLink) {
           showError('管理员未设置聊天链接');
diff --git a/web/air/src/constants/channel.constants.js b/web/air/src/constants/channel.constants.js
index 04fe94f17a..e7b25399b9 100644
--- a/web/air/src/constants/channel.constants.js
+++ b/web/air/src/constants/channel.constants.js
@@ -30,6 +30,8 @@ export const CHANNEL_OPTIONS = [
   { key: 42, text: 'VertexAI', value: 42, color: 'blue' },
   { key: 43, text: 'Proxy', value: 43, color: 'blue' },
   { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
+  { key: 45, text: 'xAI', value: 45, color: 'blue' },
+  { key: 46, text: 'Replicate', value: 46, color: 'blue' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
diff --git a/web/air/src/pages/Channel/EditChannel.js b/web/air/src/pages/Channel/EditChannel.js
index 73fd2da200..4a810830bd 100644
--- a/web/air/src/pages/Channel/EditChannel.js
+++ b/web/air/src/pages/Channel/EditChannel.js
@@ -43,6 +43,7 @@ const EditChannel = (props) => {
     base_url: '',
     other: '',
     model_mapping: '',
+    system_prompt: '',
     models: [],
     auto_ban: 1,
     groups: ['default']
@@ -63,7 +64,7 @@ const EditChannel = (props) => {
     let localModels = [];
     switch (value) {
       case 14:
-        localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
+        localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-haiku-20241022", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022"];
         break;
       case 11:
         localModels = ['PaLM-2'];
@@ -78,7 +79,7 @@ const EditChannel = (props) => {
         localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
         break;
       case 18:
-        localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
+        localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v3.5-32K', 'SparkDesk-v4.0'];
        break;
       case 19:
         localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
@@ -304,163 +305,163 @@ const EditChannel = (props) => {
       width={isMobile() ? '100%' : 600}
     >
-
+
       类型:
       {
-          handleInputChange('base_url', value)
-        }}
-        value={inputs.base_url}
-        autoComplete='new-password'
-      />
-
-      默认 API 版本:
-
-      {
-          handleInputChange('other', value)
-        }}
-        value={inputs.other}
-        autoComplete='new-password'
-      />
-    )
+    inputs.type === 3 && (
+      <>
+
+        注意,模型部署名称必须和模型名称保持一致,因为 One API 会把请求体中的 model 参数替换为你的部署名称(模型名称中的点会被剔除),图片演示。
+
+      }>
+
+
+        AZURE_OPENAI_ENDPOINT:
+
+        {
+            handleInputChange('base_url', value)
+          }}
+          value={inputs.base_url}
+          autoComplete='new-password'
+        />
+
+        默认 API 版本:
+
+        {
+            handleInputChange('other', value)
+          }}
+          value={inputs.other}
+          autoComplete='new-password'
+        />
+    )
 }
 {
-  inputs.type === 8 && (
-    <>
-
-      Base URL:
-
-      {
-        handleInputChange('base_url', value)
-      }}
-      value={inputs.base_url}
-      autoComplete='new-password'
-    />
-  )
+  inputs.type === 8 && (
+    <>
+
+      Base URL:
+
+      {
+        handleInputChange('base_url', value)
+      }}
+      value={inputs.base_url}
+      autoComplete='new-password'
+    />
+  )
 }
-
+
 名称:
 {
-  handleInputChange('name', value)
-  }}
-  value={inputs.name}
-  autoComplete='new-password'
+  required
+  name='name'
+  placeholder={'请为渠道命名'}
+  onChange={value => {
+    handleInputChange('name', value)
+  }}
+  value={inputs.name}
+  autoComplete='new-password'
 />
-
+
 分组:
 {
-  handleInputChange('other', value)
-  }}
-  value={inputs.other}
-  autoComplete='new-password'
-  />
-  )
+  inputs.type === 18 && (
+    <>
+
+      模型版本:
+
+      {
+        handleInputChange('other', value)
+      }}
+      value={inputs.other}
+      autoComplete='new-password'
+      />
+  )
 }
 {
-  inputs.type === 21 && (
-    <>
-
-      知识库 ID:
-
-      {
-        handleInputChange('other', value)
-      }}
-      value={inputs.other}
-      autoComplete='new-password'
-      />
-  )
+  inputs.type === 21 && (
+    <>
+
+      知识库 ID:
+
+      {
+        handleInputChange('other', value)
+      }}
+      value={inputs.other}
+      autoComplete='new-password'
+      />
+  )
 }
-
+
 模型:
 填入
-  }
-  placeholder='输入自定义模型名称'
-  value={customModel}
-  onChange={(value) => {
-    setCustomModel(value.trim());
-  }}
+  addonAfter={
+
+  }
+  placeholder='输入自定义模型名称'
+  value={customModel}
+  onChange={(value) => {
+    setCustomModel(value.trim());
+  }}
 />
-
+
 模型重定向:
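The `setSystemPrompt` helper added in `relay/controller/helper.go` in this patch either rewrites the request's first message (when it is already a system message) or prepends a new system message. A standalone sketch of that logic follows, using minimal stand-in types (`Message`, `Request`, `roleSystem`) in place of the real `relaymodel.Message` and `role.System`, and dropping the logging:

```go
package main

import "fmt"

// Stand-in for relaymodel.Message; Content is `any` because one-api
// allows both string and structured content parts.
type Message struct {
	Role    string
	Content any
}

// Stand-in for the relevant part of GeneralOpenAIRequest.
type Request struct {
	Messages []Message
}

const roleSystem = "system"

// setSystemPrompt mirrors the patch's logic: if a channel-level system
// prompt is configured, it replaces an existing leading system message
// or prepends one, and reports whether the request was modified.
func setSystemPrompt(request *Request, prompt string) (reset bool) {
	if prompt == "" || len(request.Messages) == 0 {
		return false
	}
	if request.Messages[0].Role == roleSystem {
		request.Messages[0].Content = prompt // rewrite existing system prompt
		return true
	}
	// prepend a new system message ahead of the user's messages
	request.Messages = append([]Message{{Role: roleSystem, Content: prompt}}, request.Messages...)
	return true
}

func main() {
	r := &Request{Messages: []Message{{Role: "user", Content: "hi"}}}
	fmt.Println(setSystemPrompt(r, "be concise"), r.Messages[0].Role) // true system

	r2 := &Request{Messages: []Message{{Role: "system", Content: "old"}, {Role: "user", Content: "hi"}}}
	fmt.Println(setSystemPrompt(r2, "new"), r2.Messages[0].Content) // true new
}
```

The returned `reset` flag is what `RelayTextHelper` threads through to `postConsumeQuota` so that the billing log can note "系统提示词已被重置".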