Skip to content

Commit

Permalink
'Refactored by Sourcery'
Browse files Browse the repository at this point in the history
  • Loading branch information
Sourcery AI committed Oct 9, 2023
1 parent 3766b24 commit ebb74ed
Show file tree
Hide file tree
Showing 29 changed files with 186 additions and 287 deletions.
14 changes: 5 additions & 9 deletions libs/legacy/app/api/agent_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,11 @@

def parse_filter_params(request: Request):
query_params = request.query_params
filter_params = {}

for k, v in query_params.items():
if k.startswith("filter[") and k.endswith("]"):
# Removing 'filter[' from start and ']' from end
filter_key = k[7:-1]
filter_params[filter_key] = v

return filter_params
return {
k[7:-1]: v
for k, v in query_params.items()
if k.startswith("filter[") and k.endswith("]")
}


@router.post(
Expand Down
14 changes: 5 additions & 9 deletions libs/legacy/app/api/agent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,11 @@

def parse_filter_params(request: Request):
query_params = request.query_params
filter_params = {}

for k, v in query_params.items():
if k.startswith("filter[") and k.endswith("]"):
# Removing 'filter[' from start and ']' from end
filter_key = k[7:-1]
filter_params[filter_key] = v

return filter_params
return {
k[7:-1]: v
for k, v in query_params.items()
if k.startswith("filter[") and k.endswith("]")
}


@router.post(
Expand Down
6 changes: 3 additions & 3 deletions libs/legacy/app/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ async def read_library_agents(token=Depends(JWTBearer())):
)
async def read_agent(agentId: str, token=Depends(JWTBearer())):
"""Agent detail endpoint"""
agent = prisma.agent.find_unique(where={"id": agentId}, include={"prompt": True})

if agent:
if agent := prisma.agent.find_unique(
where={"id": agentId}, include={"prompt": True}
):
return {"success": True, "data": agent}

logger.error(f"Agent with id: {agentId} not found")
Expand Down
40 changes: 17 additions & 23 deletions libs/legacy/app/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@
@router.post("/auth/sign-in", response_model=SignInOutput)
async def sign_in(signIn: SignIn):
try:
user = prisma.user.find_first(
if user := prisma.user.find_first(
where={
"email": signIn.email,
},
include={"profile": True},
)
if user:
):
validated = validatePassword(signIn.password, user.password)
del user.password

Expand All @@ -34,16 +33,12 @@ async def sign_in(signIn: SignIn):
return {"success": True, "data": SignInOut(token=token, user=user)}

logger.warning("Invalid password")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
)
else:
logger.warning("User not found")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
)
except Exception as e:
logger.error("Couldn't find user by email", exc_info=e)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
Expand Down Expand Up @@ -96,16 +91,15 @@ async def oauth_handler(body: OAuth):
)
if user:
return {"success": True, "data": user}
else:
user = prisma.user.create(
{
"email": body.email,
"provider": body.provider,
"name": body.name,
"accessToken": body.access_token,
}
)
prisma.profile.create({"userId": user.id})
user = prisma.user.create(
{
"email": body.email,
"provider": body.provider,
"name": body.name,
"accessToken": body.access_token,
}
)
prisma.profile.create({"userId": user.id})

if user:
return {"success": True, "data": user}
if user:
return {"success": True, "data": user}
5 changes: 3 additions & 2 deletions libs/legacy/app/api/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ async def read_user_me(token=Depends(JWTBearer())):
@router.get("/users/{userId}", response_model=UserOutput)
async def read_user(userId: str):
try:
user = prisma.user.find_unique(where={"id": userId}, include={"profile": True})
if user:
if user := prisma.user.find_unique(
where={"id": userId}, include={"profile": True}
):
return {"success": True, "data": user}
else:
logger.error("Couldn't find user")
Expand Down
12 changes: 3 additions & 9 deletions libs/legacy/app/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@ def get_agent(self, session: str | None = None) -> AgentExecutor:
llm = self.agent_base._get_llm()
memory = self.agent_base._get_memory(session)
prompt = self.agent_base._get_prompt()
agent = LLMChain(
return LLMChain(
llm=llm,
memory=memory,
verbose=True,
prompt=prompt,
output_key="output",
)

return agent


class OpenAIAgent(AgentStrategy):
def __init__(self, agent_base: AgentBase):
Expand All @@ -40,7 +38,7 @@ def get_agent(self, session: str | None = None) -> AgentExecutor:
tools = self.agent_base._get_tools()
memory = self.agent_base._get_memory(session)
prompt = self.agent_base._get_prompt()
agent = initialize_agent(
return initialize_agent(
tools=tools,
llm=llm,
agent=AgentType.OPENAI_FUNCTIONS,
Expand All @@ -58,8 +56,6 @@ def get_agent(self, session: str | None = None) -> AgentExecutor:
return_intermediate_steps=True,
)

return agent


class ReactAgent(AgentStrategy):
def __init__(self, agent_base):
Expand All @@ -79,12 +75,10 @@ def get_agent(self, session: str | None = None) -> AgentExecutor:
stop=["\nObservation:"],
allowed_tools=tool_names,
)
agent = AgentExecutor.from_agent_and_tools(
return AgentExecutor.from_agent_and_tools(
agent=agent_config,
tools=tools,
verbose=True,
memory=memory,
return_intermediate_steps=True,
)

return agent
28 changes: 12 additions & 16 deletions libs/legacy/app/lib/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
self.tools = self._get_agent_tools()

def _get_api_key(self) -> str:
if self.llm["provider"] == "openai-chat" or self.llm["provider"] == "openai":
if self.llm["provider"] in ["openai-chat", "openai"]:
return (
self.llm["api_key"]
if "api_key" in self.llm
Expand Down Expand Up @@ -161,7 +161,7 @@ def _get_llm(self, has_streaming: bool = True) -> Any:
),
],
)
if self.has_streaming and has_streaming != False
if self.has_streaming and has_streaming
else ChatOpenAI(
model_name=self.llm["model"],
openai_api_key=self._get_api_key(),
Expand Down Expand Up @@ -189,7 +189,7 @@ def _get_llm(self, has_streaming: bool = True) -> Any:
)
],
)
if self.has_streaming and has_streaming != False
if self.has_streaming and has_streaming
else ChatAnthropic(anthropic_api_key=self._get_api_key())
)

Expand All @@ -207,8 +207,10 @@ def _get_llm(self, has_streaming: bool = True) -> Any:
)
],
)
if self.has_streaming and has_streaming != False
else Cohere(cohere_api_key=self._get_api_key(), model=self.llm["model"])
if self.has_streaming and has_streaming
else Cohere(
cohere_api_key=self._get_api_key(), model=self.llm["model"]
)
)

if self.llm["provider"] == "azure-openai":
Expand Down Expand Up @@ -286,12 +288,10 @@ def _get_memory(self, session) -> ConversationBufferMemory:
return ConversationBufferMemory(memory_key="chat_history", output_key="output")

def _get_agent_documents(self) -> Any:
agent_documents = prisma.agentdocument.find_many(
return prisma.agentdocument.find_many(
where={"agentId": self.id}, include={"document": True}
)

return agent_documents

def _get_tool_and_input_by_type(
self, type: str, metadata: dict = None
) -> Tuple[Any, Any]:
Expand Down Expand Up @@ -407,12 +407,10 @@ def _get_tools(self) -> list[Tool]:
return tools

def _get_agent_tools(self) -> Any:
tools = prisma.agenttool.find_many(
return prisma.agenttool.find_many(
where={"agentId": self.id}, include={"tool": True}
)

return tools

def _format_trace(self, trace: Any) -> dict:
if self.documents or self.tools:
return json.dumps(
Expand Down Expand Up @@ -463,7 +461,7 @@ def cache_message(

metadata = {
"agentId": agent_id,
"sessionId": str(session_id),
"sessionId": session_id,
"text": query,
"cached_message": ai_message,
"type": "cache",
Expand All @@ -476,13 +474,11 @@ def cache_message(

def get_cached_result(self, query: str) -> str | None:
vectorstore: PineconeVectorStore = VectorStoreBase().get_database()
results = vectorstore.query(
if results := vectorstore.query(
prompt=query,
metadata_filter={"agentId": self.id},
min_score=0.9,
)

if results:
):
timestamp: float = results[0].metadata.get("timestamp", 0.0)

if timestamp and time.time() - timestamp > self.cache_ttl:
Expand Down
7 changes: 1 addition & 6 deletions libs/legacy/app/lib/agents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,8 @@ def create_agent(agent_base: AgentBase):
if agent_base.tools or agent_base.documents:
return OpenAIAgent(agent_base)

return DefaultAgent(agent_base)

elif agent_base.type == "REACT":
if agent_base.tools or agent_base.documents:
return ReactAgent(agent_base)

return DefaultAgent(agent_base)

else:
return DefaultAgent(agent_base)
return DefaultAgent(agent_base)
3 changes: 1 addition & 2 deletions libs/legacy/app/lib/api_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@


def generate_api_token() -> str:
api_token = uuid.uuid4().hex
return api_token
return uuid.uuid4().hex
Loading

0 comments on commit ebb74ed

Please sign in to comment.