forked from geekr-dev/chatgpt-client
-
Notifications
You must be signed in to change notification settings - Fork 1
/
chat.go
139 lines (122 loc) · 2.85 KB
/
chat.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package main
import (
"bufio"
"context"
"fmt"
"os"
"strings"
"github.com/charmbracelet/glamour"
"github.com/common-nighthawk/go-figure"
gpt3 "github.com/sashabaranov/go-gpt3"
)
func main() {
// 获取 OpenAI API Key
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
fmt.Println("请设置 OPENAI_API_KEY 环境变量")
return
}
// 初始化 Glamour 渲染器
renderStyle := glamour.WithEnvironmentConfig()
mdRenderer, err := glamour.NewTermRenderer(
renderStyle,
)
if err != nil {
fmt.Println("初始化 Markdown 渲染器失败")
return
}
// 输出欢迎语
myFigure := figure.NewFigure("ChatGPT", "", true)
myFigure.Print()
fmt.Println()
// 创建 ChatGPT 客户端
client := gpt3.NewClient(apiKey)
if err != nil {
fmt.Printf("创建客户端失败: %s\n", err.Error())
return
}
messages := []gpt3.ChatCompletionMessage{
{
Role: "system",
Content: "你是ChatGPT, OpenAI训练的大型语言模型, 请尽可能简洁地回答我的问题",
},
}
for {
fmt.Print("👽 ")
// 读取用户输入并交互
inputReader := bufio.NewReader(os.Stdin)
userInput, err := inputReader.ReadString('\n')
if err != nil {
fmt.Println(err)
continue
}
if userInput == "" || userInput == "\n" {
continue
}
if strings.HasSuffix(userInput, "\\c\n") {
// 数组还原
messages = []gpt3.ChatCompletionMessage{
{
Role: "system",
Content: "你是ChatGPT, OpenAI训练的大型语言模型, 请尽可能简洁地回答我的问题",
},
}
fmt.Println("会话已重置")
continue
}
messages = append(
messages, gpt3.ChatCompletionMessage{
Role: "user",
Content: userInput,
},
)
if len(messages) > 4096 {
// 数组还原
messages = []gpt3.ChatCompletionMessage{
{
Role: "system",
Content: "你是ChatGPT, OpenAI训练的大型语言模型, 请尽可能简洁地回答我的问题",
},
}
fmt.Println("会话已重置")
// 重新添加消息
messages = append(
messages, gpt3.ChatCompletionMessage{
Role: "user",
Content: userInput,
},
)
}
// 调用 ChatGPT API 接口生成回答
resp, err := client.CreateChatCompletion(
context.Background(),
gpt3.ChatCompletionRequest{
Model: gpt3.GPT3Dot5Turbo,
Messages: messages,
MaxTokens: 1024,
Temperature: 0,
N: 1,
},
)
if err != nil {
fmt.Printf("ChatGPT 接口调用失败: %s\n", err.Error())
userInput = ""
continue
}
// 格式化输出结果
output := resp.Choices[0].Message.Content
mdOutput, err := mdRenderer.Render(output)
if err != nil {
fmt.Printf("Markdown 渲染失败: %s\n", err.Error())
userInput = ""
continue
}
fmt.Println("🤖 " + mdOutput)
messages = append(
messages, gpt3.ChatCompletionMessage{
Role: "assistant",
Content: output,
},
)
}
}