Skip to content

Commit

Permalink
support context
Browse files Browse the repository at this point in the history
  • Loading branch information
CsterKuroi committed Mar 28, 2023
1 parent bcead82 commit f127fc0
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 34 deletions.
27 changes: 19 additions & 8 deletions chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,33 @@ import (
gpt35 "github.com/poorjobless/wechatbot/gpt-client"
)

func Completions(msg string) (string, error) {
// Cache is the process-wide session store used by Completions to keep
// per-session conversation history between calls.
var Cache CacheInterface

func init() {
	Cache = GetSessionCache()
}

// Completions appends msg (as a user message) to the cached conversation
// for session, sends the full history to the GPT-3.5 chat API, caches the
// assistant's reply back into the session, and returns the reply text.
func Completions(session, msg string) (string, error) {
	ms := Cache.GetMsg(session)
	message := &gpt35.Message{
		Role:    gpt35.RoleUser, // use the package constant, not a raw "user" literal
		Content: msg,
	}
	ms = append(ms, *message)
	c := gpt35.NewClient(config.LoadConfig().ApiKey, config.LoadConfig().Proxy)
	req := &gpt35.Request{
		Model:       gpt35.ModelGpt35Turbo,
		Messages:    ms,
		MaxTokens:   1000,
		Temperature: 0.7,
	}

	resp, err := c.GetChat(req)
	if err != nil {
		// Return the error instead of panicking: every caller already
		// handles the error path, and a transient API failure must not
		// crash the whole bot process.
		return "", err
	}
	// NOTE(review): assumes the API always returns at least one choice on
	// success — confirm against the gpt35 client; an empty Choices slice
	// would panic here.
	ms = append(ms, resp.Choices[0].Message)
	Cache.SetMsg(session, ms)

	return resp.Choices[0].Message.Content, nil
}
74 changes: 74 additions & 0 deletions chatgpt/sessionCache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package chatgpt

import (
"encoding/json"
"time"

"github.com/patrickmn/go-cache"

gpt35 "github.com/poorjobless/wechatbot/gpt-client"
)

// SessionService stores per-session conversation history in an
// in-memory, TTL-based cache.
type SessionService struct {
	cache *cache.Cache
}

// SessionMeta is the value cached per session: the ordered list of chat
// messages exchanged so far.
type SessionMeta struct {
	Msg []gpt35.Message `json:"msg,omitempty"`
}

// CacheInterface abstracts session-history storage so callers depend on
// behaviour rather than on the go-cache-backed implementation.
type CacheInterface interface {
	GetMsg(sessionId string) []gpt35.Message
	SetMsg(sessionId string, msg []gpt35.Message)
	Clear(sessionId string)
}

// sessionServices is the lazily-created singleton returned by
// GetSessionCache.
var sessionServices *SessionService

// getLength returns the total JSON-encoded size, in bytes, of all
// messages in strPool. SetMsg uses it to bound how much history a
// session retains.
func getLength(strPool []gpt35.Message) int {
	var total int
	for _, v := range strPool {
		// Marshal error ignored: Message holds only plain string
		// fields, so encoding cannot fail here.
		bytes, _ := json.Marshal(v)
		// len of the []byte equals len of its string form; the
		// original string(bytes) conversion only allocated a copy.
		total += len(bytes)
	}
	return total
}

// GetMsg returns the cached conversation history for sessionId, or nil
// when the session is unknown. Callers treat nil as an empty history.
func (s *SessionService) GetMsg(sessionId string) (msg []gpt35.Message) {
	sessionContext, ok := s.cache.Get(sessionId)
	if !ok {
		return nil
	}
	// Comma-ok assertion: an unexpected value type under this key yields
	// an empty history instead of panicking the caller.
	sessionMeta, ok := sessionContext.(*SessionMeta)
	if !ok {
		return nil
	}
	return sessionMeta.Msg
}

// SetMsg stores msg as the conversation history for sessionId, first
// trimming old entries (always keeping the first message) until the
// JSON-encoded size fits maxLength. Entries expire after 12 hours.
func (s *SessionService) SetMsg(sessionId string, msg []gpt35.Message) {
	maxLength := 4096
	maxCacheTime := time.Hour * 12

	// Drop the second-oldest entry while over budget. The len(msg) > 1
	// guard fixes a slice-bounds panic in the original loop: with a
	// single message still over maxLength, msg[2:] is out of range.
	for len(msg) > 1 && getLength(msg) > maxLength {
		msg = append(msg[:1], msg[2:]...)
	}

	sessionContext, ok := s.cache.Get(sessionId)
	if !ok {
		// First write for this session: create its metadata record.
		s.cache.Set(sessionId, &SessionMeta{Msg: msg}, maxCacheTime)
		return
	}
	// Existing session: update in place and refresh the TTL.
	sessionMeta := sessionContext.(*SessionMeta)
	sessionMeta.Msg = msg
	s.cache.Set(sessionId, sessionMeta, maxCacheTime)
}

// Clear removes all cached conversation history for sessionId.
func (s *SessionService) Clear(sessionId string) {
	s.cache.Delete(sessionId)
}

// GetSessionCache returns the process-wide session cache, creating it on
// first use with a 12-hour default TTL and hourly expired-entry cleanup.
// NOTE(review): the lazy init is not goroutine-safe; in practice it is
// only invoked from a package init(), which runs single-threaded.
func GetSessionCache() CacheInterface {
	if sessionServices != nil {
		return sessionServices
	}
	sessionServices = &SessionService{cache: cache.New(12*time.Hour, time.Hour)}
	return sessionServices
}
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ module github.com/poorjobless/wechatbot

go 1.16

require github.com/eatmoreapple/openwechat v1.4.1
require (
github.com/eatmoreapple/openwechat v1.4.1
github.com/patrickmn/go-cache v2.1.0+incompatible
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
github.com/eatmoreapple/openwechat v1.4.1 h1:hIVEr2Xaj+r1SXzdTigqhIXiuu6TZd+NPWdEVVt/qeM=
github.com/eatmoreapple/openwechat v1.4.1/go.mod h1:ZxMcq7IpVWVU9JG7ERjExnm5M8/AQ6yZTtX30K3rwRQ=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
43 changes: 21 additions & 22 deletions gpt-client/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,37 @@ type RoleType string
type ModelType string

type Request struct {
Model ModelType `json:"model"`
Messages []*Message `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Stream bool `json:"stream,omitempty"`
Stop interface{} `json:"stop,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
LogitBias interface{} `json:"logit_bias,omitempty"`
User string `json:"user,omitempty"`
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"`
}

type Response struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []*Choice `json:"choices"`
Usage *Usage `json:"usage"`
Error *Error `json:"error,omitempty"`
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Message Message `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokes int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
}

type Message struct {
Role RoleType `json:"role,omitempty"`
Role RoleType `json:"role"`
Content string `json:"content"`
}

type Choice struct {
Index int `json:"index"`
Message *Message `json:"message"`
FinishReason string `json:"finish_reason"`
Index int `json:"index"`
Message Message `json:"message"`
FinishReason string `json:"finish_reason"`
}

type Usage struct {
Expand Down
9 changes: 7 additions & 2 deletions handlers/group_msg_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,15 @@ func (g *GroupMessageHandler) ReplyText(msg *openwechat.Message) error {
return nil
}

if strings.Contains(msg.Content, "clear") {
chatgpt.Cache.Clear(group.ID())
log.Println("clear", group.ID())
return nil
}
// 替换掉@文本,然后向GPT发起请求
replaceText := "@" + sender.NickName
requestText := strings.TrimSpace(strings.ReplaceAll(msg.Content, replaceText, ""))
reply, err := chatgpt.Completions(requestText)
reply, err := chatgpt.Completions(group.ID(), requestText)
if err != nil {
log.Printf("chatgpt request error: %v \n", err)
msg.ReplyText("机器人神了,我一会发现了就去修。")
Expand All @@ -63,7 +68,7 @@ func (g *GroupMessageHandler) ReplyText(msg *openwechat.Message) error {
// 回复@我的用户
reply = strings.TrimSpace(reply)
reply = strings.Trim(reply, "\n")
atText := "@" + groupSender.NickName + ": "
atText := "@" + groupSender.NickName + " "
replyText := atText + reply
_, err = msg.ReplyText(replyText)
if err != nil {
Expand Down
9 changes: 8 additions & 1 deletion handlers/user_msg_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/eatmoreapple/openwechat"

"github.com/poorjobless/wechatbot/chatgpt"
)

Expand Down Expand Up @@ -33,10 +34,16 @@ func (g *UserMessageHandler) ReplyText(msg *openwechat.Message) error {
sender, err := msg.Sender()
log.Printf("Received User %v Text Msg : %v", sender.NickName, msg.Content)

if strings.Contains(msg.Content, "clear") {
chatgpt.Cache.Clear(sender.ID())
log.Println("clear", sender.ID())
return nil
}

// 向GPT发起请求
requestText := strings.TrimSpace(msg.Content)
requestText = strings.Trim(msg.Content, "\n")
reply, err := chatgpt.Completions(requestText)
reply, err := chatgpt.Completions(sender.ID(), requestText)
if err != nil {
log.Printf("chatgpt request error: %v \n", err)
msg.ReplyText("机器人神了,我一会发现了就去修。")
Expand Down

0 comments on commit f127fc0

Please sign in to comment.