diff --git a/controller/relay-claude.go b/controller/relay-claude.go
index f2894b45fd..6eed122738 100644
--- a/controller/relay-claude.go
+++ b/controller/relay-claude.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/ctxkey"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
@@ -113,6 +114,14 @@ func relayTextHelper(c *gin.Context) *relay_model.ErrorWithStatusCode {
 	modelRatio := billingratio.GetModelRatio(request.Model, meta.ChannelType)
 	groupRatio := billingratio.GetGroupRatio(meta.Group)
 	ratio := modelRatio * groupRatio
+	// Reject the request up front if the user's quota cannot cover the estimate.
+	promptTokens := getPromptTokens(request)
+	meta.PromptTokens = promptTokens
+	bizErr := validQuota(ctx, request, promptTokens, ratio, meta)
+	if bizErr != nil {
+		logger.Warnf(ctx, "validQuota failed: %+v", *bizErr)
+		return bizErr
+	}
 
 	adaptor := getAdaptor(meta.APIType)
 	usage, bizError := adaptor.DoRequest(c, request, meta)
@@ -125,6 +134,42 @@ func relayTextHelper(c *gin.Context) *relay_model.ErrorWithStatusCode {
 	return nil
 }
 
+// getPromptTokens returns the number of prompt tokens to account for request.
+// It is currently a placeholder that always returns 1.
+func getPromptTokens(request *anthropic.Request) int {
+	// TODO: calculate the real input token count from request.
+	return 1
+}
+
+// validQuota returns an error when the quota reported by
+// model.CacheGetUserQuota for meta.UserId is smaller than the estimated cost
+// of the request. It only compares values and does not deduct anything.
+func validQuota(ctx context.Context, request *anthropic.Request, promptTokens int, ratio float64, meta *meta.Meta) *relay_model.ErrorWithStatusCode {
+	preConsumedQuota := getPreConsumedQuota(request, promptTokens, ratio)
+
+	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
+	if err != nil {
+		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+	}
+	// Compare directly rather than subtracting so extreme values cannot wrap.
+	if userQuota < preConsumedQuota {
+		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+	}
+	return nil
+}
+
+// getPreConsumedQuota estimates the quota a request may cost:
+// config.PreConsumedQuota plus promptTokens plus request.MaxTokens, scaled by
+// ratio and truncated to an integer.
+func getPreConsumedQuota(request *anthropic.Request, promptTokens int, ratio float64) int64 {
+	preConsumedTokens := config.PreConsumedQuota + int64(promptTokens)
+	// Only add positive values so a negative MaxTokens cannot lower the estimate.
+	if request.MaxTokens > 0 {
+		preConsumedTokens += int64(request.MaxTokens)
+	}
+	return int64(float64(preConsumedTokens) * ratio)
+}
+
 func getAndValidateRequest(c *gin.Context, mode int) (*anthropic.Request, error) {
 	request := &anthropic.Request{}
 	err := common.UnmarshalBodyReusable(c, request)
diff --git a/model/channel.go b/model/channel.go
index 198d053ae2..c4762df9e2 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -143,6 +143,9 @@ func (channel *Channel) Insert() error {
 
 func (channel *Channel) Update() error {
 	var err error
+	// NOTE(review): presumably zeroed so the struct-based Updates call below
+	// leaves the stored UsedQuota untouched (zero-value fields skipped) — confirm.
+	channel.UsedQuota = 0
 	err = DB.Model(channel).Updates(channel).Error
 	if err != nil {
 		return err