Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: refactor code
Browse files Browse the repository at this point in the history
elisalimli committed Mar 12, 2024
1 parent acc65f6 commit c8c076d
Showing 2 changed files with 24 additions and 12 deletions.
2 changes: 1 addition & 1 deletion libs/superagent/app/agents/base.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ def __init__(
callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
llm_params: Optional[LLMParams] = {},
agent_config: Agent = None,
memory_config: MemoryDb = None,
memory_config: Optional[MemoryDb] = None,
):
self.agent_id = agent_id
self.session_id = session_id
34 changes: 23 additions & 11 deletions libs/superagent/app/agents/langchain.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
import json
import logging
import re
from typing import Any, List
from typing import Any, List, Optional

from decouple import config
from langchain.agents import AgentType, initialize_agent
@@ -26,6 +26,7 @@
from app.tools.datasource import DatasourceTool, StructuredDatasourceTool
from app.utils.helpers import get_first_non_null
from app.utils.llm import LLM_MAPPING
from prisma.enums import LLMProvider, MemoryDbProvider
from prisma.models import LLM, Agent, AgentDatasource, AgentTool, MemoryDb

logger = logging.getLogger(__name__)
@@ -152,7 +153,7 @@ async def _get_llm(self, llm: LLM, model: str) -> Any:
**(self.llm_params.dict() if self.llm_params else {}),
}

if llm.provider == "OPENAI":
if llm.provider == LLMProvider.OPENAI:
return ChatOpenAI(
model=LLM_MAPPING[model],
openai_api_key=llm.apiKey,
@@ -161,7 +162,7 @@ async def _get_llm(self, llm: LLM, model: str) -> Any:
**(llm.options if llm.options else {}),
**(llm_params),
)
elif llm.provider == "AZURE_OPENAI":
elif llm.provider == LLMProvider.AZURE_OPENAI:
return AzureChatOpenAI(
api_key=llm.apiKey,
streaming=self.enable_streaming,
@@ -197,15 +198,19 @@ async def _get_prompt(self, agent: Agent) -> str:
content = f"{content}" f"\n\n{datetime.datetime.now().strftime('%Y-%m-%d')}"
return SystemMessage(content=content)

async def _get_memory(self, memory_db: MemoryDb) -> List:
async def _get_memory(self, memory_db: Optional[MemoryDb]) -> List:
logger.debug(f"Use memory config: {memory_db}")
if memory_db is None:
memory_provider = config("MEMORY")
memory_provider = config("MEMORY", "motorhead")
options = {}
else:
memory_provider = memory_db.provider
options = memory_db.options
if memory_provider == "REDIS" or memory_provider == "redis":

memory_provider = memory_provider.upper()
logger.info(f"Using memory provider: {memory_provider}")

if memory_provider == MemoryDbProvider.REDIS:
memory = ConversationBufferWindowMemory(
chat_memory=RedisChatMessageHistory(
session_id=(
@@ -227,18 +232,25 @@ async def _get_memory(self, memory_db: MemoryDb) -> List:
config("REDIS_MEMORY_WINDOW", 10),
),
)
elif memory_provider == "MOTORHEAD" or memory_provider == "motorhead":
elif memory_provider == MemoryDbProvider.MOTORHEAD:
url = get_first_non_null(
options.get("MEMORY_API_URL"),
config("MEMORY_API_URL"),
)

if not url:
raise ValueError(
"Memory API URL is required for Motorhead memory provider"
)

memory = MotorheadMemory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
memory_key="chat_history",
url=get_first_non_null(
options.get("MEMORY_API_URL"),
config("MEMORY_API_URL"),
),
url=url,
return_messages=True,
output_key="output",
)

0 comments on commit c8c076d

Please sign in to comment.