-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagentclass.py
146 lines (121 loc) · 4.09 KB
/
agentclass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from typing import List, Dict, Callable
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage,
BaseMessage,
)
from langchain.agents import Tool
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from langchain.agents import load_tools
import streamlit as st
from langchain.callbacks import StreamlitCallbackHandler
# @title DialogueAgent and DialogueSimulator classes
class DialogueAgent:
    """A named conversation participant backed by a chat model.

    Each agent keeps its own line-by-line transcript of the dialogue. When
    asked to speak, it sends that transcript, preceded by its system message,
    to the model and returns the reply text.
    """

    def __init__(
        self,
        name: str,
        system_message: SystemMessage,
        model: ChatOpenAI,
    ) -> None:
        """Store the agent's identity and model, then start a fresh transcript.

        Args:
            name: Speaker name; also used to build the ``"<name>: "`` prefix.
            system_message: Message placed first in every model call.
            model: Chat model that ``send`` calls.
        """
        self.model = model
        self.system_message = system_message
        self.name = name
        # The prefix is appended to the transcript so the model continues
        # the conversation as this speaker.
        self.prefix = f"{self.name}: "
        self.reset()

    def reset(self):
        """Discard the transcript, keeping only the opening header line."""
        self.message_history = ["Here is the conversation so far."]

    def send(self) -> str:
        """Ask the chat model for this agent's next line.

        The transcript lines and the agent's own prefix are joined with
        newlines into a single human message, which follows the system
        message in the model call.

        Returns:
            The ``content`` of the model's reply.
        """
        transcript = "\n".join([*self.message_history, self.prefix])
        reply = self.model(
            [self.system_message, HumanMessage(content=transcript)]
        )
        return reply.content

    def receive(self, name: str, message: str) -> None:
        """Append ``message``, attributed to ``name``, to the transcript.

        Args:
            name: Name of the speaker who produced the message.
            message: Text that was spoken.
        """
        entry = f"{name}: {message}"
        self.message_history.append(entry)
class DialogueSimulator:
    """Drive a turn-based conversation among a group of dialogue agents.

    On every turn the selection function picks one agent by index, that
    agent produces a message, and the message is delivered to every agent
    (the speaker included). A step counter tracks how many messages have
    been injected or spoken since the last reset.
    """

    def __init__(
        self,
        agents: List[DialogueAgent],
        selection_function: Callable[[int, List[DialogueAgent]], int],
    ) -> None:
        """Set up the simulator.

        Args:
            agents: Participants of the conversation.
            selection_function: Called as ``f(step, agents)``; must return
                the index into ``agents`` of the next speaker.
        """
        self.agents = agents
        self._step = 0
        self.select_next_speaker = selection_function

    def reset(self) -> None:
        """Return the simulation to its initial state.

        Every agent's history is reset and the step counter goes back to 0.
        """
        for agent in self.agents:
            agent.reset()
        # Restart the clock too; otherwise a reused simulator would hand a
        # stale step to the selection function and resume the speaker order
        # mid-sequence instead of starting over.
        self._step = 0

    def inject(self, name: str, message: str) -> None:
        """Deliver ``message`` from ``name`` to every agent.

        Used to start the conversation with a message that did not come
        from one of the agents. Counts as one step.

        Args:
            name: Name the message is attributed to.
            message: Text delivered to all agents.
        """
        for agent in self.agents:
            agent.receive(name, message)
        # increment time
        self._step += 1

    def step(self) -> tuple[str, str]:
        """Run one turn of the conversation.

        Returns:
            A ``(speaker_name, message)`` pair for the turn just taken.
        """
        # 1. choose the next speaker
        speaker_idx = self.select_next_speaker(self._step, self.agents)
        speaker = self.agents[speaker_idx]
        # 2. next speaker sends message
        message = speaker.send()
        # 3. everyone (the speaker included) receives the message
        for receiver in self.agents:
            receiver.receive(speaker.name, message)
        # 4. increment time
        self._step += 1
        return speaker.name, message
# @title DialogueAgentWithTools class
class DialogueAgentWithTools(DialogueAgent):
    """A ``DialogueAgent`` that produces its replies through a tool-using agent chain."""

    def __init__(
        self,
        name: str,
        system_message: SystemMessage,
        model: ChatOpenAI,
        tool_names: List[str],
        **tool_kwargs,
    ) -> None:
        """Initialise the base agent and load the requested tools.

        Args:
            name: Speaker name, forwarded to the base class.
            system_message: System message, forwarded to the base class.
            model: Chat model, forwarded to the base class.
            tool_names: Names passed to ``load_tools``.
            **tool_kwargs: Extra keyword arguments passed to ``load_tools``.
        """
        super().__init__(name, system_message, model)
        self.tools = load_tools(tool_names, **tool_kwargs)

    def send(self) -> str:
        """Run a tool-enabled agent chain on the conversation so far.

        A new chain with a new, empty conversation buffer is built on every
        call. The prompt is the system message text, the transcript lines and
        this agent's prefix joined with newlines. A ``StreamlitCallbackHandler``
        bound to a new ``st.container()`` is passed as the run's callback.

        Returns:
            The text returned by the agent chain.
        """
        buffer_memory = ConversationBufferMemory(
            memory_key="chat_history", return_messages=True
        )
        chain = initialize_agent(
            self.tools,
            self.model,
            agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
            verbose=True,
            memory=buffer_memory,
        )

        prompt_lines = [
            self.system_message.content,
            *self.message_history,
            self.prefix,
        ]
        prompt = "\n".join(prompt_lines)

        streamlit_handler = StreamlitCallbackHandler(st.container())
        answer = chain.run(input=prompt, callbacks=[streamlit_handler])
        # Wrap the result in an AIMessage and read the content back out,
        # exactly as before.
        reply = AIMessage(content=answer)
        return reply.content