-from typing import AsyncGenerator, AsyncIterator
+from typing import AsyncGenerator, AsyncIterator, Optional
 
 import sentry_sdk
 from loguru import logger
 
 from langchain_core.messages.base import BaseMessage as LangchainBaseMessage
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables.base import Runnable
 from langchain.chat_models import init_chat_model
 
 from vocode.streaming.action.abstract_factory import AbstractActionFactory
 from vocode.streaming.action.default_factory import DefaultActionFactory
 from vocode.streaming.agent.anthropic_utils import merge_bot_messages_for_langchain
 from vocode.streaming.agent.base_agent import GeneratedResponse, RespondAgent, StreamedResponse
 from vocode.streaming.agent.streaming_utils import collate_response_async, stream_response_async
-from vocode.streaming.models.actions import FunctionFragment
 from vocode.streaming.models.agent import LangchainAgentConfig
 from vocode.streaming.models.events import Sender
 from vocode.streaming.models.message import BaseMessage, LLMToken
@@ -26,21 +26,26 @@ def __init__(
         self,
         agent_config: LangchainAgentConfig,
         action_factory: AbstractActionFactory = DefaultActionFactory(),
+        chain: Optional[Runnable] = None,
         **kwargs,
     ):
         super().__init__(
             agent_config=agent_config,
             action_factory=action_factory,
             **kwargs,
         )
-        self.model = init_chat_model(model=self.agent_config.model_name, model_provider=self.agent_config.provider, temperature=self.agent_config.temperature, max_tokens=self.agent_config.max_tokens)
+        self.chain = chain if chain else self.create_chain()
+
+    def create_chain(self):
+        model = init_chat_model(model=self.agent_config.model_name, model_provider=self.agent_config.provider, temperature=self.agent_config.temperature, max_tokens=self.agent_config.max_tokens)
         messages_for_prompt_template = [
             ("placeholder", "{chat_history}")
         ]
         if self.agent_config.prompt_preamble:
             messages_for_prompt_template.insert(0, ("system", self.agent_config.prompt_preamble))
-        self.prompt_template = ChatPromptTemplate.from_messages(messages_for_prompt_template)
-        self.chain = self.prompt_template | self.model
+        prompt_template = ChatPromptTemplate.from_messages(messages_for_prompt_template)
+        chain = prompt_template | model
+        return chain
 
     async def token_generator(
         self,
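
This change lets callers inject their own LCEL Runnable instead of the default prompt | model pipeline built by create_chain(). A minimal sketch of what that might look like, assuming the class is vocode's LangchainAgent at the module path shown and that LangchainAgentConfig accepts the fields referenced in this diff; the prompt text, model choice, and parameter values are illustrative, not part of this change:

from langchain.chat_models import init_chat_model
from langchain_core.prompts import ChatPromptTemplate

# Assumed module path for the class edited in this diff.
from vocode.streaming.agent.langchain_agent import LangchainAgent
from vocode.streaming.models.agent import LangchainAgentConfig

# Illustrative config; only fields referenced in this diff are shown.
agent_config = LangchainAgentConfig(
    model_name="gpt-4o",
    provider="openai",
    temperature=0.3,
    max_tokens=256,
    prompt_preamble="You are a concise voice assistant.",
)

# Build a custom prompt/model pair to pass in as `chain`.
custom_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a concise voice assistant."),
        ("placeholder", "{chat_history}"),
    ]
)
custom_model = init_chat_model(
    model="gpt-4o", model_provider="openai", temperature=0.3
)

# When `chain` is provided, create_chain() is never called.
agent = LangchainAgent(
    agent_config=agent_config,
    chain=custom_prompt | custom_model,
)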