-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase_chat.py
73 lines (65 loc) · 2.35 KB
/
base_chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import json
import os
# Directory that holds this module; the token ledger is stored right next to it.
current_path = os.path.dirname(__file__)
token_cost_path = os.path.join(current_path, 'token_cost.json')

# Seed the ledger with an empty JSON object the first time the module is
# imported, so that later reads always find a parseable file.
ledger_missing = not os.path.exists(token_cost_path)
if ledger_missing:
    with open(token_cost_path, 'w', encoding='utf-8') as f:
        json.dump({}, f)
class aichat:
    """Base class for AI chat clients with a JSON-backed token ledger.

    ``asend`` is a no-op stub here (presumably overridden by concrete
    clients -- confirm against subclasses). The remaining methods record
    token consumption per group / API / user in the file at
    ``token_cost_path`` and expose the last response and token count.
    """

    headers: dict  # request headers
    data: dict  # request payload
    response: str  # response text
    usage: dict  # token usage; holds the three counters below
    completion_tokens: int  # completion tokens
    prompt_tokens: int  # prompt tokens
    total_tokens: int = 0  # total tokens

    def __init__(self):
        pass

    async def asend(self, msg, gid, uid) -> dict:
        """Asynchronously send a question to the AI.

        The base implementation does nothing and returns ``None``.
        """
        pass

    async def token_cost_record(self, gid, uid, cost, api):
        """Add ``cost`` tokens to the ledger entry for one user.

        The ledger is a nested mapping ``{gid: {api: {uid: tokens}}}``; each
        ``{uid: tokens}`` mapping also carries a ``'total'`` key holding the
        sum of all user entries for that group and API.

        Args:
            gid (int): Group number.
            uid (int): User number (QQ number).
            cost (int): Number of tokens consumed.
            api (str): Name of the API that was called.

        Raises:
            json.JSONDecodeError: If the ledger file exists but is not
                valid JSON (it is deliberately not overwritten then).
        """
        gid = str(gid)
        uid = str(uid)
        try:
            with open(token_cost_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            # The ledger file may have been removed while the process was
            # running; start from an empty ledger instead of crashing.
            data = {}
        per_user = data.setdefault(gid, {}).setdefault(api, {})
        per_user[uid] = per_user.get(uid, 0) + cost
        # Recompute the total from the user entries so it can never drift.
        per_user['total'] = sum(
            value for key, value in per_user.items() if key != 'total'
        )
        with open(token_cost_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=4)

    async def token_cost_record_new(self, gid, uid, usage, api):
        """Record token consumption straight from an API ``usage`` mapping.

        Uses ``usage['total_tokens']`` when it is present and convertible to
        ``int``; otherwise falls back to ``completion_tokens`` plus
        ``prompt_tokens``. The result is stored in ``self.total_tokens`` and
        forwarded to ``token_cost_record``.

        Args:
            gid (int): Group number.
            uid (int): User number (QQ number).
            usage (dict): The ``usage`` mapping returned by the API.
            api (str): Name of the API that was called.

        Raises:
            KeyError, TypeError, ValueError: If neither the total nor the
                two partial counters can be read from ``usage``.
        """
        gid = str(gid)
        uid = str(uid)
        try:
            cost = int(usage['total_tokens'])
        except (KeyError, TypeError, ValueError):
            # Only lookup/conversion failures trigger the fallback; anything
            # else (e.g. KeyboardInterrupt) must propagate.
            cost = int(usage['completion_tokens']) + int(usage['prompt_tokens'])
        self.total_tokens = cost
        await self.token_cost_record(gid, uid, cost, api)

    def get_response(self):
        """Return the AI response with surrounding whitespace stripped."""
        return self.response.strip()

    def get_usage(self):
        """Return the total tokens consumed by the most recent call."""
        return self.total_tokens
if __name__ == '__main__':
    # Imported here because asyncio is only needed for this manual smoke run.
    import asyncio

    a = aichat()
    # token_cost_record is a coroutine function: calling it only creates a
    # coroutine object, so it has to be driven by an event loop to execute.
    asyncio.run(a.token_cost_record(123456, 123456789, 100, 'api'))