forked from TheExplainthis/ChatGPT-Line-Bot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
313 lines (286 loc) · 13.9 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
from dotenv import load_dotenv
load_dotenv('.env')
from flask import Flask, request, abort
from linebot import (
LineBotApi, WebhookHandler
)
from linebot.exceptions import (
InvalidSignatureError
)
from linebot.models import (
MessageEvent, TextMessage, TextSendMessage, ImageSendMessage, AudioMessage
)
import os
import uuid
from src.models import OpenAIModel
from src.memory import Memory
from src.logger import logger
from src.storage import Storage, FileStorage, MongoStorage
from src.utils import get_role_and_content
from src.service.youtube import Youtube, YoutubeTranscriptReader
from src.service.bilibili import Bilibili, BilibiliTranscriptReader
from src.service.pdf import PDF, PDFQA
from src.service.website import Website, WebsiteReader
from src.mongodb import mongodb
from waitress import serve
import tempfile
import requests
import re
# Flask application that serves the webhook ("/callback") and health ("/") routes below.
app = Flask(__name__)
# LINE Messaging API client and webhook handler; credentials come from the
# environment (load_dotenv('.env') runs at the top of this file).
line_bot_api = LineBotApi(os.getenv('LINE_CHANNEL_ACCESS_TOKEN'))
handler = WebhookHandler(os.getenv('LINE_CHANNEL_SECRET'))
# Replaced with a Storage instance in the __main__ block; it stays None when
# this module is imported instead of being run as a script.
storage = None
# Service helpers used by the /Chat and /ChatPDF commands.
# NOTE(review): the meaning of `step` is defined in src.service — confirm there.
youtube = Youtube(step=4)
bilibili = Bilibili(step=2)
pdf = PDF()
website = Website()
# Conversation memory shared by all handlers; entries are keyed by LINE user id
# or group id (see handle_text_message).
# NOTE(review): memory_message_count presumably bounds the retained history — confirm in src.memory.
memory = Memory(system_message=os.getenv('SYSTEM_MESSAGE'), memory_message_count=2)
# Registries keyed by LINE user id (after /Reg) or group id (after /RegGroup).
model_management = {}  # id -> OpenAIModel
pdfqa_management = {}  # id -> PDFQA
# Not referenced anywhere else in the visible part of this file.
api_keys = {}
@app.route("/callback", methods=['POST'])
def callback():
    """LINE webhook endpoint.

    Reads the ``X-Line-Signature`` header and the raw request body, logs the
    body, and passes both to the webhook handler, which dispatches to the
    registered message handlers.

    Returns:
        The string ``'OK'`` when the handler accepts the request.

    Aborts with HTTP 400 when the handler raises ``InvalidSignatureError``.
    """
    line_signature = request.headers['X-Line-Signature']
    payload = request.get_data(as_text=True)
    app.logger.info(f"Request body: {payload}")
    try:
        handler.handle(payload, line_signature)
    except InvalidSignatureError:
        print("Invalid signature. Please check your channel access token/channel secret.")
        abort(400)
    else:
        return 'OK'
@handler.add(MessageEvent, message=TextMessage)
def handle_text_message(event):
    """Handle an incoming LINE text message and reply to it.

    The message text is interpreted as one of the slash commands below; text
    that matches no command produces no reply.

    Commands:
        /Reg <key>      Validate an OpenAI API key and register it for the user.
        /RegGroup       Share the sender's registration with the current group.
        /Help           Show the command list.
        /SysMsg <text>  Change the system message of the conversation.
        /History        Show the stored conversation history.
        /Clear          Clear the history and reset the PDF collection.
        /Image <text>   Generate an image from the prompt.
        /Chat <text>    Chat; if the text contains a URL, summarise that
                        YouTube video, Bilibili video or web page instead.
        /ChatPDF <text> Load a PDF from a link, or query the loaded PDFs.

    Conversation state (memory, model, PDF collection) is keyed by the group
    id when the message comes from a group and by the user id otherwise.

    Args:
        event: The LINE ``MessageEvent`` carrying a ``TextMessage``.

    Errors are not raised to the caller; they are translated into a reply:
    ``ValueError`` -> "invalid token", ``KeyError`` -> "please register",
    anything else -> the conversation memory is cleared and the error text
    (or a friendlier equivalent) is sent back.
    """
    user_id = event.source.user_id
    # Sources without a group_id attribute are treated as one-to-one chats.
    group_id = getattr(event.source, 'group_id', None)
    # Single key for every per-conversation lookup below.
    chat_id = user_id if group_id is None else group_id
    text = event.message.text.strip()
    # Never write the API key carried by /Reg into the log.
    logged_text = '/Reg [REDACTED]' if text.startswith('/Reg ') else text
    logger.info(f'{user_id}: {logged_text}' + (f' (from group {group_id})' if group_id else ''))

    msg = None
    try:
        if text.startswith('/Reg '):
            api_key = text[4:].strip()
            model = OpenAIModel(api_key=api_key)
            is_successful, _, _ = model.check_token_valid()
            if not is_successful:
                raise ValueError('Invalid API token')
            # Registration is always per user, even when sent from a group.
            model_management[user_id] = model
            pdfqa_management[user_id] = PDFQA(openai_api_key=api_key)
            storage.save({
                user_id: api_key
            })
            msg = TextSendMessage(text='Token 有效,註冊成功')
        elif text.startswith('/RegGroup'):
            if group_id is None:
                msg = TextSendMessage(text='该命令仅在群组中有效')
            else:
                # KeyError here means the sender has not registered yet.
                model_management[group_id] = model_management[user_id]
                pdfqa_management[group_id] = PDFQA(openai_api_key=pdfqa_management[user_id].openai_api_key)
                msg = TextSendMessage(text='用户具有有效 token,群组註冊成功')
        elif text.startswith('/Help'):
            msg = TextSendMessage(text=(
                "指令:\n"
                "/Reg + API Token\n👉 API Token 請先到 https://platform.openai.com/ 註冊登入後取得\n\n"
                "/RegGroup\n👉 已注册的用户可以为其所在的群组注册,注册后群组中的人共用同一个 API Token 以及历史信息\n\n"
                "/SysMsg + Prompt\n👉 Prompt 可以命令機器人扮演某個角色,例如:請你扮演擅長做總結的人\n\n"
                "/History\n👉 打印当前对话中存储的历史内容\n\n"
                "/Clear\n👉 這個指令能夠清除歷史訊息\n\n"
                "/Image + Prompt\n👉 會調用 DALL∙E 2 Model,以文字生成圖像\n\n"
                "/Chat + Prompt\n👉 調用 ChatGPT 以文字回覆\n\n"
                "語音輸入\n👉 會調用 Whisper 模型,先將語音轉換成文字,再調用 ChatGPT 以文字回覆"
            ))
        elif text.startswith('/SysMsg'):
            memory.change_system_message(chat_id, text[7:].strip())
            msg = TextSendMessage(text='輸入成功')
        elif text.startswith('/History'):
            history = memory.get(chat_id)
            msg = TextSendMessage(text=f'对话历史:\n{history}')
        elif text.startswith('/Clear'):
            memory.remove(chat_id)
            # Rebuild the PDF collection from scratch with the same API key.
            pdfqa_management[chat_id] = PDFQA(openai_api_key=pdfqa_management[chat_id].openai_api_key)
            msg = TextSendMessage(text='歷史訊息清除成功')
        elif text.startswith('/Image'):
            prompt = text[6:].strip()
            memory.append(chat_id, 'user', prompt)
            is_successful, response, error_message = model_management[chat_id].image_generations(prompt)
            if not is_successful:
                raise Exception(error_message)
            url = response['data'][0]['url']
            msg = ImageSendMessage(
                original_content_url=url,
                preview_image_url=url
            )
            memory.append(chat_id, 'assistant', url)
        elif text.startswith('/Chat '):
            text = text[5:].strip()
            # Look the model up first so unregistered chats fail before
            # anything is written to memory.
            user_model = model_management[chat_id]
            memory.append(chat_id, 'user', text)
            url = website.get_url_from_text(text)
            if url:
                role, response = _summarize_url(text, url, user_model)
            else:
                is_successful, response, error_message = user_model.chat_completions(
                    memory.get(chat_id), os.getenv('OPENAI_MODEL_ENGINE'))
                if not is_successful:
                    raise Exception(error_message)
                role, response = get_role_and_content(response)
            msg = TextSendMessage(text=response)
            memory.append(chat_id, role, response)
        elif text.startswith('/ChatPDF '):
            text = text[8:].strip()
            # Models and PDF collections are always registered together, so
            # this lookup doubles as the "is registered" check.
            user_pdfqa = pdfqa_management[chat_id]
            memory.append(chat_id, 'user', text)
            pdf_link = pdf.get_pdf_link(text)
            if pdf_link:
                pdf_path = _download_pdf(pdf_link)
                try:
                    user_pdfqa.add(pdf_path)
                except ValueError as e:
                    # Re-raise as a plain Exception so it is not reported as
                    # an invalid-token error by the handler below.
                    raise Exception(str(e)) from e
                msg = TextSendMessage(text=f"The PDF file is loaded. Now {len(user_pdfqa.docs)} PDF in the collection.")
            elif len(user_pdfqa.docs) == 0:
                msg = TextSendMessage(text="Please load a PDF file first")
            else:
                try:
                    ans = user_pdfqa.query(text)
                except Exception as e:
                    # Same reason as above: keep KeyError/ValueError from the
                    # query out of the token-related handlers.
                    raise Exception(str(e)) from e
                msg = TextSendMessage(text=ans.formatted_answer)
    except ValueError:
        msg = TextSendMessage(text='Token 無效,請重新註冊,格式為 /Reg sk-xxxxx')
    except KeyError:
        msg = TextSendMessage(text='請先註冊 Token,格式為 /Reg sk-xxxxx')
    except Exception as e:
        # Reset the conversation that actually failed: history is stored
        # under the group id for group chats, not under the sender's id.
        memory.remove(chat_id)
        error_text = str(e)
        if error_text.startswith('Incorrect API key provided'):
            msg = TextSendMessage(text='OpenAI API Token 有誤,請重新註冊。')
        elif error_text.startswith('That model is currently overloaded with other requests.'):
            msg = TextSendMessage(text='已超過負荷,請稍後再試')
        else:
            msg = TextSendMessage(text=error_text)
    if msg is not None:
        line_bot_api.reply_message(event.reply_token, msg)


def _summarize_url(text, url, user_model):
    """Summarise the YouTube video, Bilibili video or web page named in *text*.

    Args:
        text: The user's message (used to detect a video id).
        url: The URL extracted from *text* (used for plain web pages).
        user_model: The model object of the current conversation.

    Returns:
        The ``(role, content)`` pair produced by ``get_role_and_content``.

    Raises:
        Exception: When the content cannot be fetched or summarised.
    """
    engine = os.getenv('OPENAI_MODEL_ENGINE')
    youtube_id = youtube.retrieve_video_id(text)
    bilibili_id = None if youtube_id else bilibili.retrieve_video_id(text)
    if youtube_id:
        is_successful, chunks, error_message = youtube.get_transcript_chunks(youtube_id)
        if not is_successful:
            raise Exception(error_message)
        reader = YoutubeTranscriptReader(user_model, engine)
    elif bilibili_id:
        is_successful, chunks, error_message = bilibili.get_transcript_chunks(bilibili_id)
        if not is_successful:
            raise Exception(error_message)
        reader = BilibiliTranscriptReader(user_model, engine)
    else:
        chunks = website.get_content_from_url(url)
        if not chunks:
            raise Exception('無法撈取此網站文字')
        reader = WebsiteReader(user_model, engine)
    is_successful, response, error_message = reader.summarize(chunks)
    if not is_successful:
        raise Exception(error_message)
    return get_role_and_content(response)


def _safe_pdf_filename(pdf_link, content_disposition):
    """Return a bare ``*.pdf`` file name for a download, never a path.

    The name is taken from the ``Content-Disposition`` header when it carries
    a ``filename=`` parameter, otherwise from the URL path. Both are supplied
    by the remote side, so directory components are stripped.
    """
    candidate = ''
    if content_disposition:
        found = re.findall(r'filename=(.+)', content_disposition)
        if found:
            # Drop trailing header parameters and surrounding quotes.
            candidate = found[0].split(';', 1)[0].strip().strip('"\'')
    if not candidate:
        # Ignore the fragment and query string of the URL.
        candidate = pdf_link.split('#', 1)[0].split('?', 1)[0]
    # Keep only the last path component so the file cannot escape pdf_dir.
    candidate = os.path.basename(candidate.replace('\\', '/'))
    if candidate in ('', '.', '..'):
        candidate = str(uuid.uuid4())
    if not candidate.lower().endswith('.pdf'):
        candidate += '.pdf'
    return candidate


def _download_pdf(pdf_link, timeout=30):
    """Download *pdf_link* into the module-level ``pdf_dir`` and return its path.

    ``pdf_dir`` is the temporary directory created in the ``__main__`` block.

    Args:
        pdf_link: URL of the PDF to fetch.
        timeout: Seconds to wait for the remote server before giving up.

    Raises:
        Exception: When the request fails or the response is not a PDF.
    """
    try:
        response = requests.get(pdf_link, timeout=timeout)
    except requests.RequestException as e:
        # Some RequestException subclasses (e.g. InvalidURL, MissingSchema)
        # also derive from ValueError; convert so the caller does not report
        # them as an invalid API token.
        raise Exception(str(e)) from e
    # .get() avoids a KeyError (reported as "not registered") when the header
    # is absent; parameters such as "; charset=..." are ignored.
    content_type = response.headers.get('content-type', '').split(';', 1)[0].strip()
    if not content_type.endswith("pdf"):
        raise Exception("The PDF file cannot be downloaded.")
    pdf_fname = _safe_pdf_filename(pdf_link, response.headers.get("Content-Disposition"))
    pdf_path = os.path.join(pdf_dir, pdf_fname)
    with open(pdf_path, "wb") as pdf_file:
        pdf_file.write(response.content)
    return pdf_path
@handler.add(MessageEvent, message=AudioMessage)
def handle_audio_message(event):
    """Handle an incoming LINE audio message and reply with a chat answer.

    The audio is downloaded to a temporary ``.m4a`` file in the working
    directory, transcribed with the ``whisper-1`` model, and the transcript is
    sent to the chat model as the user's message. The temporary file is
    removed whether or not processing succeeds.

    Args:
        event: The LINE ``MessageEvent`` carrying an ``AudioMessage``.
    """
    user_id = event.source.user_id
    audio_content = line_bot_api.get_message_content(event.message.id)
    input_audio_path = f'{uuid.uuid4()}.m4a'
    try:
        with open(input_audio_path, 'wb') as fd:
            for chunk in audio_content.iter_content():
                fd.write(chunk)
        msg = _reply_for_audio(user_id, input_audio_path)
    finally:
        # Clean up even when writing the download fails part-way.
        if os.path.exists(input_audio_path):
            os.remove(input_audio_path)
    line_bot_api.reply_message(event.reply_token, msg)


def _reply_for_audio(user_id, audio_path):
    """Transcribe *audio_path* for *user_id* and build the reply message.

    Failures are converted into a user-facing ``TextSendMessage`` instead of
    being raised; on unexpected errors the user's memory is cleared.
    """
    try:
        model = model_management.get(user_id)
        if not model:
            raise ValueError('Invalid API token')
        is_successful, response, error_message = model.audio_transcriptions(audio_path, 'whisper-1')
        if not is_successful:
            raise Exception(error_message)
        memory.append(user_id, 'user', response['text'])
        # Use the same configurable engine as the text handler; fall back to
        # the previous hard-coded model when the variable is unset.
        engine = os.getenv('OPENAI_MODEL_ENGINE') or 'gpt-3.5-turbo'
        is_successful, response, error_message = model.chat_completions(memory.get(user_id), engine)
        if not is_successful:
            raise Exception(error_message)
        role, response = get_role_and_content(response)
        memory.append(user_id, role, response)
        return TextSendMessage(text=response)
    except ValueError:
        return TextSendMessage(text='請先註冊你的 API Token,格式為 /Reg [API TOKEN]')
    except KeyError:
        return TextSendMessage(text='請先註冊 Token,格式為 /Reg sk-xxxxx')
    except Exception as e:
        memory.remove(user_id)
        if str(e).startswith('Incorrect API key provided'):
            return TextSendMessage(text='OpenAI API Token 有誤,請重新註冊。')
        return TextSendMessage(text=str(e))
@app.route("/", methods=['GET'])
def home():
    """Root endpoint; answers GET requests with a fixed greeting."""
    greeting = 'Hello World'
    return greeting
if __name__ == "__main__":
    # Pick the persistence backend: MongoDB when USE_MONGO is set, otherwise
    # a local JSON file.
    if os.getenv('USE_MONGO'):
        mongodb.connect_to_database()
        storage = Storage(MongoStorage(mongodb.db))
    else:
        storage = Storage(FileStorage('db.json'))

    # Restore the registrations saved by earlier /Reg commands, if any.
    try:
        data = storage.load()
        for user_id, saved_key in data.items():
            model_management[user_id] = OpenAIModel(api_key=saved_key)
            pdfqa_management[user_id] = PDFQA(openai_api_key=saved_key)
    except FileNotFoundError:
        pass

    # pdf_dir is read as a module global by the /ChatPDF command, so the
    # temporary directory has to stay open for as long as the server runs.
    with tempfile.TemporaryDirectory() as pdf_dir:
        serve(app, host="0.0.0.0", port=8080)