
Commit 392c979

add chain input field and create_chain() method
1 parent 1b7bc87 commit 392c979

File tree

1 file changed: +10 -5 lines changed

vocode/streaming/agent/langchain_agent.py (+10 -5)
@@ -1,18 +1,18 @@
-from typing import AsyncGenerator, AsyncIterator
+from typing import AsyncGenerator, AsyncIterator, Optional

 import sentry_sdk
 from loguru import logger

 from langchain_core.messages.base import BaseMessage as LangchainBaseMessage
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables.base import Runnable
 from langchain.chat_models import init_chat_model

 from vocode.streaming.action.abstract_factory import AbstractActionFactory
 from vocode.streaming.action.default_factory import DefaultActionFactory
 from vocode.streaming.agent.anthropic_utils import merge_bot_messages_for_langchain
 from vocode.streaming.agent.base_agent import GeneratedResponse, RespondAgent, StreamedResponse
 from vocode.streaming.agent.streaming_utils import collate_response_async, stream_response_async
-from vocode.streaming.models.actions import FunctionFragment
 from vocode.streaming.models.agent import LangchainAgentConfig
 from vocode.streaming.models.events import Sender
 from vocode.streaming.models.message import BaseMessage, LLMToken
@@ -26,21 +26,26 @@ def __init__(
         self,
         agent_config: LangchainAgentConfig,
         action_factory: AbstractActionFactory = DefaultActionFactory(),
+        chain: Optional[Runnable] = None,
         **kwargs,
     ):
         super().__init__(
             agent_config=agent_config,
             action_factory=action_factory,
             **kwargs,
         )
-        self.model = init_chat_model(model = self.agent_config.model_name, model_provider=self.agent_config.provider, temperature=self.agent_config.temperature, max_tokens=self.agent_config.max_tokens)
+        self.chain = chain if chain else self.create_chain()
+
+    def create_chain(self):
+        model = init_chat_model(model = self.agent_config.model_name, model_provider=self.agent_config.provider, temperature=self.agent_config.temperature, max_tokens=self.agent_config.max_tokens)
         messages_for_prompt_template = [
             ("placeholder", "{chat_history}")
         ]
         if self.agent_config.prompt_preamble:
             messages_for_prompt_template.insert(0, ("system", self.agent_config.prompt_preamble))
-        self.prompt_template = ChatPromptTemplate.from_messages(messages_for_prompt_template)
-        self.chain = self.prompt_template | self.model
+        prompt_template = ChatPromptTemplate.from_messages(messages_for_prompt_template)
+        chain = prompt_template | model
+        return chain

     async def token_generator(
         self,
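
For context, here is a rough usage sketch of the new constructor parameter (not part of the commit). It assumes a LangchainAgentConfig populated with the fields the diff reads (model_name, provider, prompt_preamble, temperature, max_tokens); any additional required config fields are omitted, and the model names and prompt text are placeholders.

# Usage sketch only -- not from the commit. Config values and the custom prompt
# are illustrative; LangchainAgentConfig may require additional fields.
from langchain.chat_models import init_chat_model
from langchain_core.prompts import ChatPromptTemplate

from vocode.streaming.agent.langchain_agent import LangchainAgent
from vocode.streaming.models.agent import LangchainAgentConfig

agent_config = LangchainAgentConfig(
    model_name="gpt-4o",
    provider="openai",
    prompt_preamble="You are a helpful voice assistant.",
    temperature=0.7,
    max_tokens=256,
)

# Default path: no chain is passed, so __init__ falls back to create_chain(),
# which builds ChatPromptTemplate | init_chat_model(...) from the config.
agent = LangchainAgent(agent_config=agent_config)

# New path: inject any LCEL Runnable; create_chain() is skipped entirely.
custom_prompt = ChatPromptTemplate.from_messages(
    [("system", "Answer in one short sentence."), ("placeholder", "{chat_history}")]
)
custom_chain = custom_prompt | init_chat_model(model="gpt-4o", model_provider="openai")
custom_agent = LangchainAgent(agent_config=agent_config, chain=custom_chain)

Because chain construction now lives in its own method, subclasses can also override create_chain() instead of passing a prebuilt chain at construction time.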
