diff --git a/chatgpt/chatgpt.go b/chatgpt/chatgpt.go index c4ae7eb..208653e 100644 --- a/chatgpt/chatgpt.go +++ b/chatgpt/chatgpt.go @@ -5,22 +5,33 @@ import ( gpt35 "github.com/poorjobless/wechatbot/gpt-client" ) -func Completions(msg string) (string, error) { +var Cache CacheInterface + +func init() { + Cache = GetSessionCache() +} + +func Completions(session, msg string) (string, error) { + ms := Cache.GetMsg(session) + message := &gpt35.Message{ + Role: "user", + Content: msg, + } + ms = append(ms, *message) c := gpt35.NewClient(config.LoadConfig().ApiKey, config.LoadConfig().Proxy) req := &gpt35.Request{ - Model: gpt35.ModelGpt35Turbo, - Messages: []*gpt35.Message{ - { - Role: gpt35.RoleUser, - Content: msg, - }, - }, + Model: gpt35.ModelGpt35Turbo, + Messages: ms, + MaxTokens: 1000, + Temperature: 0.7, } resp, err := c.GetChat(req) if err != nil { panic(err) } + ms = append(ms, resp.Choices[0].Message) + Cache.SetMsg(session, ms) return resp.Choices[0].Message.Content, err } diff --git a/chatgpt/sessionCache.go b/chatgpt/sessionCache.go new file mode 100644 index 0000000..f63c7b9 --- /dev/null +++ b/chatgpt/sessionCache.go @@ -0,0 +1,74 @@ +package chatgpt + +import ( + "encoding/json" + "time" + + "github.com/patrickmn/go-cache" + + gpt35 "github.com/poorjobless/wechatbot/gpt-client" +) + +type SessionService struct { + cache *cache.Cache +} + +type SessionMeta struct { + Msg []gpt35.Message `json:"msg,omitempty"` +} + +type CacheInterface interface { + GetMsg(sessionId string) []gpt35.Message + SetMsg(sessionId string, msg []gpt35.Message) + Clear(sessionId string) +} + +var sessionServices *SessionService + +func getLength(strPool []gpt35.Message) int { + var total int + for _, v := range strPool { + bytes, _ := json.Marshal(v) + total += len(string(bytes)) + } + return total +} + +func (s *SessionService) GetMsg(sessionId string) (msg []gpt35.Message) { + sessionContext, ok := s.cache.Get(sessionId) + if !ok { + return nil + } + sessionMeta := sessionContext.(*SessionMeta) + return sessionMeta.Msg +} + +func (s *SessionService) SetMsg(sessionId string, msg []gpt35.Message) { + maxLength := 4096 + maxCacheTime := time.Hour * 12 + + for getLength(msg) > maxLength { + msg = append(msg[:1], msg[2:]...) + } + + sessionContext, ok := s.cache.Get(sessionId) + if !ok { + sessionMeta := &SessionMeta{Msg: msg} + s.cache.Set(sessionId, sessionMeta, maxCacheTime) + return + } + sessionMeta := sessionContext.(*SessionMeta) + sessionMeta.Msg = msg + s.cache.Set(sessionId, sessionMeta, maxCacheTime) +} + +func (s *SessionService) Clear(sessionId string) { + s.cache.Delete(sessionId) +} + +func GetSessionCache() CacheInterface { + if sessionServices == nil { + sessionServices = &SessionService{cache: cache.New(time.Hour*12, time.Hour*1)} + } + return sessionServices +} diff --git a/go.mod b/go.mod index c09707f..9a8fb80 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/poorjobless/wechatbot go 1.16 -require github.com/eatmoreapple/openwechat v1.4.1 +require ( + github.com/eatmoreapple/openwechat v1.4.1 + github.com/patrickmn/go-cache v2.1.0+incompatible +) diff --git a/go.sum b/go.sum index bdaee9e..11a93f2 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ github.com/eatmoreapple/openwechat v1.4.1 h1:hIVEr2Xaj+r1SXzdTigqhIXiuu6TZd+NPWdEVVt/qeM= github.com/eatmoreapple/openwechat v1.4.1/go.mod h1:ZxMcq7IpVWVU9JG7ERjExnm5M8/AQ6yZTtX30K3rwRQ= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= diff --git a/gpt-client/types.go b/gpt-client/types.go index 2af435e..abe7575 100644 --- a/gpt-client/types.go +++ b/gpt-client/types.go @@ -4,38 +4,37 @@ type RoleType string type ModelType string type Request struct { - Model ModelType `json:"model"` - Messages []*Message `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop interface{} `json:"stop,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - LogitBias interface{} `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + Model string `json:"model"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens"` + Temperature float64 `json:"temperature"` } type Response struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Choices []*Choice `json:"choices"` - Usage *Usage `json:"usage"` - Error *Error `json:"error,omitempty"` + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokes int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } } type Message struct { - Role RoleType `json:"role,omitempty"` + Role RoleType `json:"role"` Content string `json:"content"` } type Choice struct { - Index int `json:"index"` - Message *Message `json:"message"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` } type Usage struct { diff --git a/handlers/group_msg_handler.go b/handlers/group_msg_handler.go index 7f0dbf7..5f6014e 100644 --- a/handlers/group_msg_handler.go +++ b/handlers/group_msg_handler.go @@ -40,10 +40,15 @@ func (g *GroupMessageHandler) ReplyText(msg *openwechat.Message) error { return nil } + if strings.Contains(msg.Content, "clear") { + chatgpt.Cache.Clear(group.ID()) + log.Println("clear", group.ID()) + return nil + } // 替换掉@文本,然后向GPT发起请求 replaceText := "@" + sender.NickName requestText := strings.TrimSpace(strings.ReplaceAll(msg.Content, replaceText, "")) - reply, err := chatgpt.Completions(requestText) + reply, err := chatgpt.Completions(group.ID(), requestText) if err != nil { log.Printf("chatgpt request error: %v \n", err) msg.ReplyText("机器人神了,我一会发现了就去修。") @@ -63,7 +68,7 @@ func (g *GroupMessageHandler) ReplyText(msg *openwechat.Message) error { // 回复@我的用户 reply = strings.TrimSpace(reply) reply = strings.Trim(reply, "\n") - atText := "@" + groupSender.NickName + ": " + atText := "@" + groupSender.NickName + " " replyText := atText + reply _, err = msg.ReplyText(replyText) if err != nil { diff --git a/handlers/user_msg_handler.go b/handlers/user_msg_handler.go index 06f7b0b..e9d5bad 100644 --- a/handlers/user_msg_handler.go +++ b/handlers/user_msg_handler.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/eatmoreapple/openwechat" + "github.com/poorjobless/wechatbot/chatgpt" ) @@ -33,10 +34,16 @@ func (g *UserMessageHandler) ReplyText(msg *openwechat.Message) error { sender, err := msg.Sender() log.Printf("Received User %v Text Msg : %v", sender.NickName, msg.Content) + if strings.Contains(msg.Content, "clear") { + chatgpt.Cache.Clear(sender.ID()) + log.Println("clear", sender.ID()) + return nil + } + // 向GPT发起请求 requestText := strings.TrimSpace(msg.Content) requestText = strings.Trim(msg.Content, "\n") - reply, err := chatgpt.Completions(requestText) + reply, err := chatgpt.Completions(sender.ID(), requestText) if err != nil { log.Printf("chatgpt request error: %v \n", err) msg.ReplyText("机器人神了,我一会发现了就去修。")