sheepGPT.py
import argparse
import difflib
import json
import os
import time
from typing import List, Optional
import ollama
import openai
import requests
from zeroconf_listener import listener
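# `listener` comes from zeroconf_listener and is expected to expose the
# discovered server's IP as `listener.server_ip` (presumably resolved via
# zeroconf/mDNS on the local network).

# Cache of the messages seen on the previous poll, used to diff out new ones.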
CACHE_FILE = "message_cache.json"
def parse_args():
parser = argparse.ArgumentParser(
description="Choose the model to use for response generation."
)
parser.add_argument(
"--model",
type=str,
default="ollama",
choices=["ollama", "gpt-4"],
help="The model to use for response generation.",
)
return parser.parse_args()
args = parse_args()
with open("system_prompt_baaahs.txt", "r") as file:
system_prompt = file.read()
print(f"Loaded system prompt: {system_prompt[:100]}...") # Print first 100 characters
if args.model == "ollama":
model = "llama3.1:70b"
print(f"Using Ollama with {model} model for response generation.")
else:
print("Using OpenAI GPT-4 for response generation.")
    # Read the key from the environment instead of hardcoding a secret here.
    openai.api_key = os.environ["OPENAI_API_KEY"]
# Define the maximum number of retries for failed operations
max_retries = 2
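# The last thought we posted, kept so exact duplicates are skipped.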
last_posted_thought = None
def load_cached_messages() -> List[str]:
if os.path.exists(CACHE_FILE):
with open(CACHE_FILE, 'r') as f:
return json.load(f)
return []
def save_cached_messages(messages: List[str]):
with open(CACHE_FILE, 'w') as f:
json.dump(messages, f)
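
# Fetch all messages from the server and return only those not seen on the
# previous poll. The payload shape is inferred from the parsing below; the
# values here are illustrative only:
#   [{"type": "D", "str": "hello sheep"}, {"type": "Q", "str": "ignored"}]
# Only entries whose "type" is "D" are kept.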
def get_messages() -> Optional[List[str]]:
print("Getting messages...")
# Define the GET endpoint
get_endpoint = f"http://{listener.server_ip}:8080/messages"
for _ in range(max_retries):
try:
response = requests.get(get_endpoint)
messages = [msg["str"] for msg in response.json() if msg["type"] == "D"]
print(f"Got {len(messages)} messages")
cached_messages = load_cached_messages()
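            # ndiff prefixes lines unique to `messages` with "+ "; stripping
            # that two-character prefix leaves only the not-yet-seen messages.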
diff = difflib.ndiff(cached_messages, messages)
new_messages = [l[2:] for l in diff if l.startswith("+ ")]
save_cached_messages(messages)
return new_messages
except Exception as e:
print(f"Error getting messages: {e}")
return None
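
# Stream a chat completion to stdout as it arrives. Ollama and OpenAI return
# differently shaped chunks, hence the two branches below; the concatenated
# text is returned once the stream ends.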
def stream_response(response, notify_thinking: bool = False):
if notify_thinking: notify_generating_thought(True)
total_response = ""
if args.model == "ollama":
for chunk in response:
delta_message = chunk['message']['content']
if delta_message:
print(delta_message, end="", flush=True)
total_response += delta_message
else:
for chunk in response:
for choice in chunk.choices:
delta_message = choice.delta.content
if delta_message:
print(delta_message, end="", flush=True)
total_response += delta_message
print()
if notify_thinking: notify_generating_thought(False)
return total_response
def is_question(message: str) -> bool:
# Check if the message ends with a question mark or starts with a question word
question_words = ['who', 'what', 'when', 'where', 'why', 'how', 'is', 'are', 'can', 'could', 'would', 'should',
'do', 'does', 'did']
return message.strip().endswith('?') or any(message.lower().strip().startswith(word) for word in question_words)
def determine_in_reply_to(messages: List[str]) -> str:
# Find the last question in the list of messages
for message in reversed(messages):
if is_question(message):
return message
# If no question is found, return the last message
return messages[-1] if messages else ""
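
# Post a generated thought back to the server, tagged with the message it
# answers. Exact duplicates of the previously posted thought are skipped.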
def post_message(output: str, in_reply_to: str) -> bool:
global last_posted_thought
print("Posting new thought...")
post_endpoint = f"http://{listener.server_ip}:8080/newGPTReply"
if output == last_posted_thought:
print("Thought is the same as the last posted thought, skipping post.")
return False
for _ in range(max_retries):
try:
requests.post(post_endpoint, json={"answer": output, "inReplyTo": in_reply_to})
print("Thought posted")
last_posted_thought = output
return True
except Exception as e:
print(f"Error posting message: {e}")
return False
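
# Tell the server whether a thought is currently being generated (presumably
# so a front end can show a "thinking" indicator while the reply streams).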
def notify_generating_thought(generating: bool) -> bool:
print(f"Notifying sheep is thinking={generating}...")
post_endpoint = f"http://{listener.server_ip}:8080/isGeneratingThought"
for _ in range(max_retries):
try:
requests.post(post_endpoint, json={"isGenerating": generating})
print("Sheep is thinking notification posted")
return True
except Exception as e:
print(f"Error posting sheep is thinking notification: {e}")
return False
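
# Join the new messages into one user prompt and stream a reply from the
# selected model, retrying on transient failures.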
def generate_response(messages: List[str]) -> Optional[str]:
print("Generating response...")
prompt = "\n".join(messages)
for _ in range(max_retries):
try:
if args.model == "ollama":
stream = ollama.chat(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
],
stream=True
)
total_response = stream_response(stream, notify_thinking=False)
print("Response generated")
return total_response
else:
response = openai.ChatCompletion.create(
model="gpt-4-1106-preview",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
temperature=0,
stream=True,
)
total_response = stream_response(response)
print("Response generated")
return total_response
except Exception as e:
print(f"Error generating response: {e}")
notify_generating_thought(False)
return None
# post_message(generate_response(["hello", "who are you?"]), "who are you?")
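# Main loop: poll for new messages every 5 seconds, generate a reply when any
# arrive, and post it in reply to the most recent question.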
while True:
messages = get_messages()
if messages is not None:
if len(messages) > 0:
            newmessages = "\n".join(messages)
            print(f"{len(messages)} new messages:\n{newmessages}")
response = generate_response(messages)
if response is not None:
in_reply_to = determine_in_reply_to(messages)
post_message(response, in_reply_to)
else:
print("No new messages, skipping response generation.")
time.sleep(5)