Skip to content

Commit

Permalink
feat(you): custom model
Browse files Browse the repository at this point in the history
  • Loading branch information
bincooo committed Oct 31, 2024
1 parent 262e236 commit 0a44bd3
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 118 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ require (
github.com/bincooo/edge-api v1.0.4-0.20240918111026-76a4223e60d0
github.com/bincooo/emit.io v1.0.1-0.20240918104917-7aa3711f2559
github.com/bincooo/vecmul.com v0.0.0-20240918113329-241a0f273998
github.com/bincooo/you.com v0.0.0-20240918111518-4ae3958b355f
github.com/bincooo/you.com v0.0.0-20241031220121-a057b2dc1010
github.com/bogdanfinn/tls-client v1.7.7
github.com/dlclark/regexp2 v1.11.4
github.com/dop251/goja v0.0.0-20240828124009-016eb7256539
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ github.com/bincooo/emit.io v1.0.1-0.20240918104917-7aa3711f2559 h1:JL6cGQ4eXhPRa
github.com/bincooo/emit.io v1.0.1-0.20240918104917-7aa3711f2559/go.mod h1:OJbKJoZ6x6vSpCC+JNtfcaXo3ilpvQscWrcGEmtmrZI=
github.com/bincooo/vecmul.com v0.0.0-20240918113329-241a0f273998 h1:FgFjU/WZPzocFsF3Ltahj+6CbLTu4FAr9c5RZOUy0Xk=
github.com/bincooo/vecmul.com v0.0.0-20240918113329-241a0f273998/go.mod h1:iNsaIde7efj9BBtghS/noQ6OqXz1v/RQ7/NMO1s3Nzw=
github.com/bincooo/you.com v0.0.0-20240918111518-4ae3958b355f h1:GX24CoKOA9w+63QHI0WCU7t2LjWNwzHMx4py/jbMTEM=
github.com/bincooo/you.com v0.0.0-20240918111518-4ae3958b355f/go.mod h1:Rs9F3k8IIMxDSd65XvLdhroB8AcHfvzogY3yv1fZGs0=
github.com/bincooo/you.com v0.0.0-20241031220121-a057b2dc1010 h1:TgsXrQnkMB2amz6qYBZ/tAOotmMJEbTVvIhVNpl2E7A=
github.com/bincooo/you.com v0.0.0-20241031220121-a057b2dc1010/go.mod h1:Rs9F3k8IIMxDSd65XvLdhroB8AcHfvzogY3yv1fZGs0=
github.com/bogdanfinn/fhttp v0.5.28 h1:G6thT8s8v6z1IuvXMUsX9QKy3ZHseTQTzxuIhSiaaAw=
github.com/bogdanfinn/fhttp v0.5.28/go.mod h1:oJiYPG3jQTKzk/VFmogH8jxjH5yiv2rrOH48Xso2lrE=
github.com/bogdanfinn/tls-client v1.7.7 h1:c3mf6LX6bxEsunJhP2BJeJE7qN/7BniWUpIpBc9Igu8=
Expand Down
19 changes: 19 additions & 0 deletions internal/plugin/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type Adapter interface {
Completion(ctx *gin.Context)
Generation(ctx *gin.Context)
Embedding(ctx *gin.Context)
HandleMessages(ctx *gin.Context) (messages []pkg.Keyv[interface{}], err error)
}

type BaseAdapter struct {
Expand All @@ -80,6 +81,15 @@ func (BaseAdapter) Generation(*gin.Context) {

func (BaseAdapter) Embedding(*gin.Context) {}

func (BaseAdapter) HandleMessages(ctx *gin.Context) (messages []pkg.Keyv[interface{}], err error) {
var (
completion = common.GetGinCompletion(ctx)
)

messages = completion.Messages
return
}

func NewGlobalAdapter() *ExtensionAdapter {
return &ExtensionAdapter{
slice: make([]Adapter, 0),
Expand All @@ -106,6 +116,15 @@ func (adapter *ExtensionAdapter) Completion(ctx *gin.Context) {
completion := common.GetGinCompletion(ctx)
for _, extension := range adapter.slice {
if extension.Match(ctx, completion.Model) {
messages, err := extension.HandleMessages(ctx)
if err != nil {
logger.Error("Error handling messages: ", err)
response.Error(ctx, 500, err)
return
}

completion.Messages = messages
ctx.Set(vars.GinCompletion, completion)
extension.Completion(ctx)
return
}
Expand Down
13 changes: 13 additions & 0 deletions internal/plugin/llm/claude/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ func (API) Models() []plugin.Model {
}
}

func (API) HandleMessages(ctx *gin.Context) (messages []pkg.Keyv[interface{}], err error) {
var (
completion = common.GetGinCompletion(ctx)
toolMessages = common.FindToolMessages(&completion)
)

if messages, err = common.HandleMessages(completion, nil); err != nil {
return
}
messages = append(messages, toolMessages...)
return
}

func (API) Completion(ctx *gin.Context) {
var (
completion = common.GetGinCompletion(ctx)
Expand Down
10 changes: 2 additions & 8 deletions internal/plugin/llm/claude/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,10 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, chatResponse chan

func mergeMessages(ctx *gin.Context) (attachment []claude3.Attachment, tokens int, err error) {
var (
completion = common.GetGinCompletion(ctx)
messages = completion.Messages
toolMessages = common.FindToolMessages(&completion)
completion = common.GetGinCompletion(ctx)
messages = completion.Messages
)

if messages, err = common.HandleMessages(completion, nil); err != nil {
return
}
messages = append(messages, toolMessages...)

var (
pos = 0
contents []string
Expand Down
13 changes: 13 additions & 0 deletions internal/plugin/llm/coze/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,19 @@ func (API) Models() []plugin.Model {
}
}

func (API) HandleMessages(ctx *gin.Context) (messages []pkg.Keyv[interface{}], err error) {
var (
completion = common.GetGinCompletion(ctx)
toolMessages = common.FindToolMessages(&completion)
)

if messages, err = common.HandleMessages(completion, nil); err != nil {
return
}
messages = append(messages, toolMessages...)
return
}

func (API) Completion(ctx *gin.Context) {
var (
cookie = ctx.GetString("token")
Expand Down
10 changes: 2 additions & 8 deletions internal/plugin/llm/coze/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,10 @@ label:

func mergeMessages(ctx *gin.Context) (newMessages []coze.Message, tokens int, err error) {
var (
completion = common.GetGinCompletion(ctx)
messages = completion.Messages
toolMessages = common.FindToolMessages(&completion)
completion = common.GetGinCompletion(ctx)
messages = completion.Messages
)

if messages, err = common.HandleMessages(completion, nil); err != nil {
return
}
messages = append(messages, toolMessages...)

var (
pos = 0
contents []string
Expand Down
14 changes: 14 additions & 0 deletions internal/plugin/llm/lmsys/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"chatgpt-adapter/internal/plugin"
"chatgpt-adapter/internal/vars"
"chatgpt-adapter/logger"
"chatgpt-adapter/pkg"
"errors"
"github.com/gin-gonic/gin"
"strings"
Expand Down Expand Up @@ -118,6 +119,19 @@ func (API) Models() (result []plugin.Model) {
return
}

func (API) HandleMessages(ctx *gin.Context) (messages []pkg.Keyv[interface{}], err error) {
var (
completion = common.GetGinCompletion(ctx)
toolMessages = common.FindToolMessages(&completion)
)

if messages, err = common.HandleMessages(completion, nil); err != nil {
return
}
messages = append(messages, toolMessages...)
return
}

func (API) Completion(ctx *gin.Context) {
var (
token = ctx.GetString("token")
Expand Down
8 changes: 1 addition & 7 deletions internal/plugin/llm/lmsys/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,9 @@ label:

func mergeMessages(ctx *gin.Context, completion pkg.ChatCompletion) (newMessages string, err error) {
var (
messages []pkg.Keyv[interface{}]
toolMessages = common.FindToolMessages(&completion)
messages = completion.Messages
)

if messages, err = common.HandleMessages(completion, nil); err != nil {
return
}
messages = append(messages, toolMessages...)

var (
pos = 0
contents []string
Expand Down
137 changes: 52 additions & 85 deletions internal/plugin/llm/you/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,93 +128,51 @@ func (API) Match(_ *gin.Context, model string) bool {
return false
}

func (API) Models() []plugin.Model {
return []plugin.Model{
{
Id: "you/" + you.GPT_4,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.GPT_4o,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.GPT_4o_MINI,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.GPT_4_TURBO,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.OPENAI_O1,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.OPENAI_O1_MINI,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.CLAUDE_2,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.CLAUDE_3_HAIKU,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.CLAUDE_3_SONNET,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.CLAUDE_3_5_SONNET,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.CLAUDE_3_OPUS,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.GEMINI_1_0_PRO,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.GEMINI_1_5_PRO,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
},
{
Id: "you/" + you.GEMINI_1_5_FLASH,
func (API) Models() (slice []plugin.Model) {
for _, model := range []string{
you.GPT_4,
you.GPT_4_TURBO,
you.GPT_4o,
you.GPT_4o_MINI,
you.OPENAI_O1,
you.OPENAI_O1_MINI,
you.CLAUDE_2,
you.CLAUDE_3_HAIKU,
you.CLAUDE_3_SONNET,
you.CLAUDE_3_5_SONNET,
you.CLAUDE_3_OPUS,
you.GEMINI_1_0_PRO,
you.GEMINI_1_5_PRO,
you.GEMINI_1_5_FLASH,
} {
slice = append(slice, plugin.Model{
Id: "you/" + model,
Object: "model",
Created: 1686935002,
By: Model + "-adapter",
})
}
return
}

func (API) HandleMessages(ctx *gin.Context) (messages []pkg.Keyv[interface{}], err error) {
var (
completion = common.GetGinCompletion(ctx)
toolMessages = common.FindToolMessages(&completion)
)

if messages, err = common.HandleMessages(completion, &vars.Config{
Settings: &vars.ConfigSettings{
StripAssistant: true,
PromptExperiments: true,
PassParams: true,
XmlPlot: true,
},
}); err != nil {
return
}
messages = append(messages, toolMessages...)
return
}

func (API) Completion(ctx *gin.Context) {
Expand Down Expand Up @@ -275,17 +233,26 @@ label:
chat.LimitWithE(true)
chat.Client(plugin.HTTPClient)

//if err = tryCloudFlare(); err != nil {
// if err = tryCloudFlare(); err != nil {
// response.Error(ctx, -1, err)
// return
//}
// }

chat.CloudFlare(clearance, userAgent, lang)

var cancel chan error
cancel, matchers = joinMatchers(ctx, matchers)
ctx.Set(ginTokens, tokens)

if pkg.Config.GetBool("you.custom") {
err = chat.Custom(common.GetGinContext(ctx), "custom-"+completion.Model, "", false)
if err != nil {
logger.Error(err)
response.Error(ctx, -1, err)
return
}
}

ch, err := chat.Reply(common.GetGinContext(ctx), nil, fileMessage, message)
if err != nil {
logger.Error(err)
Expand Down Expand Up @@ -496,7 +463,7 @@ func Condition(cookie string) bool {
return false
}

//return true
// return true
chat := you.New(cookie, you.CLAUDE_2, vars.Proxies)
chat.Client(plugin.HTTPClient)
chat.CloudFlare(clearance, userAgent, lang)
Expand Down
8 changes: 1 addition & 7 deletions internal/plugin/llm/you/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,9 @@ label:
func mergeMessages(ctx *gin.Context, completion pkg.ChatCompletion) (fileMessage string, text string, tokens int, err error) {
text = notice
var (
messages = completion.Messages
toolMessages = common.FindToolMessages(&completion)
messages = completion.Messages
)

if messages, err = common.HandleMessages(completion, nil); err != nil {
return
}
messages = append(messages, toolMessages...)

var (
pos = 0
contents []string
Expand Down

0 comments on commit 0a44bd3

Please sign in to comment.